Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9862567master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b8f50a0537bfe7e082c5ad91b2b7ece61a0adbeb7489988e553909276bf920c
size 44217644
@@ -20,6 +20,7 @@ class Models(object):
    gpen = 'gpen'
    product_retrieval_embedding = 'product-retrieval-embedding'
    body_2d_keypoints = 'body-2d-keypoints'
    body_3d_keypoints = 'body-3d-keypoints'
    crowd_counting = 'HRNetCrowdCounting'
    panoptic_segmentation = 'swinL-panoptic-segmentation'
    image_reid_person = 'passvitb'
@@ -95,6 +96,7 @@ class Pipelines(object):
    general_recognition = 'resnet101-general-recognition'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
    body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
    body_3d_keypoints = 'canonical_body-3d-keypoints_video'
    human_detection = 'resnet18-human-detection'
    object_detection = 'vit-object-detection'
    easycv_detection = 'easycv-detection'
@@ -1,11 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# yapf: disable
from . import (action_recognition, animal_recognition, body_2d_keypoints,
               cartoon, cmdssl_video_embedding, crowd_counting, face_detection,
               face_generation, image_classification, image_color_enhance,
               image_colorization, image_denoise, image_instance_segmentation,
               body_3d_keypoints, cartoon, cmdssl_video_embedding,
               crowd_counting, face_detection, face_generation,
               image_classification, image_color_enhance, image_colorization,
               image_denoise, image_instance_segmentation,
               image_panoptic_segmentation, image_portrait_enhancement,
               image_reid_person, image_semantic_segmentation,
               image_to_image_generation, image_to_image_translation,
               object_detection, product_retrieval_embedding,
               realtime_object_detection, salient_detection, super_resolution,
               video_single_object_tracking, video_summarization, virual_tryon)
# yapf: enable
@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .body_3d_pose import BodyKeypointsDetection3D
else:
    _import_structure = {
        'body_3d_pose': ['BodyKeypointsDetection3D'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
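Note: with the lazy structure above, body_3d_pose is only imported on first attribute access; a consumer still writes the usual eager form:

# The LazyImportModule resolves the submodule when the name is first touched.
from modelscope.models.cv.body_3d_keypoints import BodyKeypointsDetection3D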
@@ -0,0 +1,246 @@
import os.path as osp
from typing import Any, Dict, List, Union

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_3d_keypoints.canonical_pose_modules import (
    TemporalModel, TransCan3Dkeys)
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['BodyKeypointsDetection3D']


class KeypointsTypes(object):
    POSES_CAMERA = 'poses_camera'
    POSES_TRAJ = 'poses_traj'


@MODELS.register_module(
    Tasks.body_3d_keypoints, module_name=Models.body_3d_keypoints)
class BodyKeypointsDetection3D(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir

        model_path = osp.join(self.model_dir, ModelFile.TORCH_MODEL_FILE)
        cfg_path = osp.join(self.model_dir, ModelFile.CONFIGURATION)
        self.cfg = Config.from_file(cfg_path)
        self._create_model()

        if not osp.exists(model_path):
            raise IOError(f'{model_path} does not exist.')

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')

        self.pretrained_state_dict = torch.load(
            model_path, map_location=self._device)
        self.load_pretrained()
        self.to_device(self._device)
        self.eval()

    def _create_model(self):
        self.model_pos = TemporalModel(
            self.cfg.model.MODEL.IN_NUM_JOINTS,
            self.cfg.model.MODEL.IN_2D_FEATURE,
            self.cfg.model.MODEL.OUT_NUM_JOINTS,
            filter_widths=self.cfg.model.MODEL.FILTER_WIDTHS,
            causal=self.cfg.model.MODEL.CAUSAL,
            dropout=self.cfg.model.MODEL.DROPOUT,
            channels=self.cfg.model.MODEL.CHANNELS,
            dense=self.cfg.model.MODEL.DENSE)

        receptive_field = self.model_pos.receptive_field()
        self.pad = (receptive_field - 1) // 2  # padding on each side
        if self.cfg.model.MODEL.CAUSAL:
            self.causal_shift = self.pad
        else:
            self.causal_shift = 0

        self.model_traj = TransCan3Dkeys(
            in_channels=self.cfg.model.MODEL.IN_NUM_JOINTS
            * self.cfg.model.MODEL.IN_2D_FEATURE,
            num_features=1024,
            out_channels=self.cfg.model.MODEL.OUT_3D_FEATURE,
            num_blocks=4,
            time_window=receptive_field)
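Note: for orientation, the receptive-field arithmetic under an assumed FILTER_WIDTHS = [3, 3, 3] (a common VideoPose3D setting; this PR's model configuration is not shown):

# per-level pads: 3 // 2 = 1, then (3 - 1) * 3 // 2 = 3, then (3 - 1) * 9 // 2 = 9
# receptive_field() = 1 + 2 * (1 + 3 + 9) = 27 frames
# self.pad = (27 - 1) // 2 = 13; causal_shift = 13 if CAUSAL else 0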
    def eval(self):
        self.model_pos.eval()
        self.model_traj.eval()

    def train(self):
        self.model_pos.train()
        self.model_traj.train()

    def to_device(self, device):
        self.model_pos = self.model_pos.to(device)
        self.model_traj = self.model_traj.to(device)

    def load_pretrained(self):
        if 'model_pos' in self.pretrained_state_dict:
            self.model_pos.load_state_dict(
                self.pretrained_state_dict['model_pos'], strict=False)
        else:
            logger.error(
                'Skip loading model_pos: key not found in pretrained_state_dict.')
        if 'model_traj' in self.pretrained_state_dict:
            self.model_traj.load_state_dict(
                self.pretrained_state_dict['model_traj'], strict=False)
        else:
            logger.error(
                'Skip loading model_traj: key not found in pretrained_state_dict.')
        logger.info('Load pretrained model done.')

    def preprocess(self, input: torch.Tensor) -> Dict[str, Any]:
        """Preprocess the 2D input joints.

        Args:
            input (Tensor): [NUM_FRAME, NUM_JOINTS, 2], input 2D human body keypoints.

        Returns:
            Dict[str, Any]: canonical 2D points and root-relative joints.
        """
        if input.device.type == 'cuda':
            input = input.data.cpu().numpy()
        elif input.device.type == 'cpu':
            input = input.data.numpy()

        pose2d = input
        pose2d_canonical = self.canonicalize_2Ds(
            pose2d, self.cfg.model.INPUT.FOCAL_LENGTH,
            self.cfg.model.INPUT.CENTER)
        pose2d_normalized = self.normalize_screen_coordinates(
            pose2d, self.cfg.model.INPUT.RES_W, self.cfg.model.INPUT.RES_H)
        pose2d_rr = pose2d_normalized
        pose2d_rr[:, 1:] -= pose2d_rr[:, :1]  # root-relative joints

        # pad along the time axis, then expand [NUM_FRAME, NUM_JOINTS, 2]
        # to [1, NUM_FRAME, NUM_JOINTS, 2]
        pose2d_rr = np.expand_dims(
            np.pad(
                pose2d_rr,
                ((self.pad + self.causal_shift, self.pad - self.causal_shift),
                 (0, 0), (0, 0)), 'edge'),
            axis=0)
        pose2d_canonical = np.expand_dims(
            np.pad(
                pose2d_canonical,
                ((self.pad + self.causal_shift, self.pad - self.causal_shift),
                 (0, 0), (0, 0)), 'edge'),
            axis=0)

        pose2d_rr = torch.from_numpy(pose2d_rr.astype(np.float32))
        pose2d_canonical = torch.from_numpy(
            pose2d_canonical.astype(np.float32))

        inputs_2d = pose2d_rr.clone()
        if torch.cuda.is_available():
            inputs_2d = inputs_2d.cuda(non_blocking=True)

        # positional model input
        if self.cfg.model.MODEL.USE_2D_OFFSETS:
            inputs_2d[:, :, 0] = 0
        else:
            inputs_2d[:, :, 1:] += inputs_2d[:, :, :1]

        return {
            'inputs_2d': inputs_2d,
            'pose2d_rr': pose2d_rr,
            'pose2d_canonical': pose2d_canonical
        }
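Note: a minimal numpy sketch of the edge padding above, assuming pad = 13 and a non-causal model (values from the filter-width example, not from this PR's config):

import numpy as np

pad, causal_shift = 13, 0                             # assumed values
pose2d_rr = np.zeros((100, 17, 2))                    # (NUM_FRAME, NUM_JOINTS, 2)
padded = np.pad(pose2d_rr, ((pad + causal_shift, pad - causal_shift),
                            (0, 0), (0, 0)), 'edge')  # replicate first/last frame
batched = np.expand_dims(padded, axis=0)              # -> (1, 126, 17, 2)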
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """3D human pose estimation.

        Args:
            input (Dict):
                inputs_2d: [1, NUM_FRAME, NUM_JOINTS, 2]
                pose2d_rr: [1, NUM_FRAME, NUM_JOINTS, 2]
                pose2d_canonical: [1, NUM_FRAME, NUM_JOINTS, 2]
                NUM_FRAME = video_frame_number + 2 * pad (edge padding added in preprocess)

        Returns:
            Dict[str, Any]:
                "camera_pose": Tensor, [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM],
                    3D human pose keypoints in the camera frame.
                "camera_traj": Tensor, [1, NUM_FRAME, 1, 3],
                    root keypoint coordinates in the camera frame.
        """
        inputs_2d = input['inputs_2d']
        pose2d_rr = input['pose2d_rr']
        pose2d_canonical = input['pose2d_canonical']

        with torch.no_grad():
            # predict 3D pose keypoints
            predicted_3d_pos = self.model_pos(inputs_2d)

            # predict the global trajectory
            b1, w1, n1, d1 = inputs_2d.shape
            input_pose2d_abs = self.get_abs_2d_pts(w1, pose2d_rr,
                                                   pose2d_canonical)
            b1, w1, n1, d1 = input_pose2d_abs.size()
            b2, w2, n2, d2 = predicted_3d_pos.size()
            if torch.cuda.is_available():
                input_pose2d_abs = input_pose2d_abs.cuda(non_blocking=True)
            predicted_3d_traj = self.model_traj(
                input_pose2d_abs.view(b1, w1, n1 * d1),
                predicted_3d_pos.view(b2 * w2, n2 * d2)).view(b2, w2, -1, 3)

        predict_dict = {
            KeypointsTypes.POSES_CAMERA: predicted_3d_pos,
            KeypointsTypes.POSES_TRAJ: predicted_3d_traj
        }
        return predict_dict

    def get_abs_2d_pts(self, input_video_frame_num, pose2d_rr,
                       pose2d_canonical):
        pad = self.pad
        w = input_video_frame_num - pad * 2

        # slide a window of one receptive field over the padded sequence
        lst_pose2d_rr = []
        lst_pose2d_canonical = []
        for i in range(pad, w + pad):
            lst_pose2d_rr.append(pose2d_rr[:, i - pad:i + pad + 1])
            lst_pose2d_canonical.append(pose2d_canonical[:,
                                                         i - pad:i + pad + 1])
        input_pose2d_rr = torch.cat(lst_pose2d_rr, dim=0)
        input_pose2d_canonical = torch.cat(lst_pose2d_canonical, dim=0)

        if self.cfg.model.MODEL.USE_CANONICAL_COORDS:
            input_pose2d_abs = input_pose2d_canonical.clone()
        else:
            input_pose2d_abs = input_pose2d_rr.clone()
            input_pose2d_abs[:, :, 1:] += input_pose2d_abs[:, :, :1]

        return input_pose2d_abs

    def canonicalize_2Ds(self, pos2d, f, c):
        cs = np.array([c[0], c[1]]).reshape(1, 1, 2)
        fs = np.array([f[0], f[1]]).reshape(1, 1, 2)
        canonical_2Ds = (pos2d - cs) / fs
        return canonical_2Ds

    def normalize_screen_coordinates(self, X, w, h):
        assert X.shape[-1] == 2
        # Map [0, w] to [-1, 1] while preserving the aspect ratio.
        return X / w * 2 - [1, h / w]
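Note: a quick numeric check of normalize_screen_coordinates (a 1920x1080 frame assumed for illustration):

import numpy as np

X = np.array([[0.0, 0.0], [1920.0, 1080.0]])
print(X / 1920 * 2 - [1, 1080 / 1920])
# [[-1.     -0.5625]
#  [ 1.      0.5625]]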
@@ -0,0 +1,233 @@
# The implementation is based on VideoPose3D, available at
# https://github.com/facebookresearch/VideoPose3D
import torch
import torch.nn as nn


class TemporalModelBase(nn.Module):
    """
    Do not instantiate this class.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels):
        super().__init__()

        # validate input
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        self.num_joints_out = num_joints_out
        self.filter_widths = filter_widths

        self.drop = nn.Dropout(dropout)
        self.relu = nn.ReLU(inplace=True)

        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)
        self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1)

    def set_bn_momentum(self, momentum):
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as # of frames.
        """
        frames = 0
        for f in self.pad:
            frames += f
        return 1 + 2 * frames

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.
        The returned value is typically 0 if causal convolutions are disabled,
        otherwise it is half the receptive field.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def forward(self, x):
        assert len(x.shape) == 4
        assert x.shape[-2] == self.num_joints_in
        assert x.shape[-1] == self.in_features

        sz = x.shape[:3]
        x = x.view(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        x = self._forward_blocks(x)

        x = x.permute(0, 2, 1)
        x = x.view(sz[0], -1, self.num_joints_out, 3)

        return x
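Note: a hedged shape check against the TemporalModel subclass defined next (filter_widths=[3, 3, 3] assumed, i.e. receptive field 27, so 100 output frames need 126 padded input frames):

import torch

model = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3]).eval()
x = torch.randn(1, 126, 17, 2)   # 100 frames, edge-padded by 13 per side
with torch.no_grad():
    y = model(x)
print(y.shape)                    # torch.Size([1, 100, 17, 3])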
class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.
    This implementation can be used for all use-cases.
    """

    def __init__(self,
                 num_joints_in,
                 in_features,
                 num_joints_out,
                 filter_widths,
                 causal=False,
                 dropout=0.25,
                 channels=1024,
                 dense=False):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output joints (can be different from input)
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)
        """
        super().__init__(num_joints_in, in_features, num_joints_out,
                         filter_widths, causal, dropout, channels)

        self.expand_conv = nn.Conv1d(
            num_joints_in * in_features,
            channels,
            filter_widths[0],
            bias=False)

        layers_conv = []
        layers_bn = []

        self.causal_shift = [(filter_widths[0]) // 2 if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
            self.causal_shift.append((filter_widths[i] // 2
                                      * next_dilation) if causal else 0)

            layers_conv.append(
                nn.Conv1d(
                    channels,
                    channels,
                    filter_widths[i] if not dense else (2 * self.pad[-1] + 1),
                    dilation=next_dilation if not dense else 1,
                    bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(
                nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def _forward_blocks(self, x):
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))

        for i in range(len(self.pad) - 1):
            pad = self.pad[i + 1]
            shift = self.causal_shift[i + 1]
            # crop the skip branch so it aligns with the dilated conv output
            res = x[:, :, pad + shift:x.shape[2] - pad + shift]

            x = self.drop(
                self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(
                self.relu(self.layers_bn[2 * i + 1](
                    self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)
        return x
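Note: the skip-branch slicing in _forward_blocks compensates for the frames a dilated conv consumes; an illustration (kernel 3, dilation 3, as in the second level of the assumed [3, 3, 3] config):

import torch

x = torch.randn(1, 8, 27)
y = torch.nn.Conv1d(8, 8, 3, dilation=3)(x)   # loses (3 - 1) * 3 = 6 frames -> (1, 8, 21)
res = x[:, :, 3:27 - 3]                       # crop pad = 3 per side -> (1, 8, 21)
assert res.shape == y.shape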
# regression of the global trajectory
class TransCan3Dkeys(nn.Module):

    def __init__(self,
                 in_channels=74,
                 num_features=256,
                 out_channels=44,
                 time_window=10,
                 num_blocks=2):
        super().__init__()

        self.in_channels = in_channels
        self.num_features = num_features
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.time_window = time_window

        self.expand_bn = nn.BatchNorm1d(self.num_features, momentum=0.1)
        self.conv1 = nn.Sequential(
            nn.ReplicationPad1d(1),
            nn.Conv1d(
                self.in_channels, self.num_features, kernel_size=3,
                bias=False), self.expand_bn, nn.ReLU(inplace=True),
            nn.Dropout(p=0.25))
        self._make_blocks()

        self.pad = nn.ReplicationPad1d(4)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout(p=0.25)
        self.reduce = nn.Conv1d(
            self.num_features, self.num_features, kernel_size=self.time_window)

        self.embedding_3d_1 = nn.Linear(in_channels // 2 * 3, 500)
        self.embedding_3d_2 = nn.Linear(500, 500)
        self.LReLU1 = nn.LeakyReLU()
        self.LReLU2 = nn.LeakyReLU()
        self.LReLU3 = nn.LeakyReLU()
        self.out1 = nn.Linear(self.num_features + 500, self.num_features)
        self.out2 = nn.Linear(self.num_features, self.out_channels)

    def _make_blocks(self):
        layers_conv = []
        layers_bn = []
        for i in range(self.num_blocks):
            layers_conv.append(
                nn.Conv1d(
                    self.num_features,
                    self.num_features,
                    kernel_size=5,
                    bias=False,
                    dilation=2))
            layers_bn.append(nn.BatchNorm1d(self.num_features))
        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def set_bn_momentum(self, momentum):
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def forward(self, p2ds, p3d):
        """
        Args:
            p2ds: (B x T x C) absolute 2D keypoints, flattened per frame
            p3d: (B x J*3) flattened 3D pose for the window
        """
        B, T, C = p2ds.shape
        x = p2ds.permute((0, 2, 1))
        x = self.conv1(x)
        for i in range(self.num_blocks):
            pre = x
            x = self.pad(x)
            x = self.layers_conv[i](x)
            x = self.layers_bn[i](x)
            x = self.drop(self.relu(x))
            x = pre + x
        x_2d = self.relu(self.reduce(x))
        x_2d = x_2d.view(B, -1)
        x_3d = self.LReLU1(self.embedding_3d_1(p3d))
        x = torch.cat((x_2d, x_3d), 1)
        x = self.LReLU3(self.out1(x))
        x = self.out2(x)
        return x
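Note: a hedged smoke test of the trajectory head with the shapes BodyKeypointsDetection3D passes in (17 joints and window = receptive field = 27 assumed, not read from the config):

import torch

traj = TransCan3Dkeys(in_channels=17 * 2, num_features=1024,
                      out_channels=3, num_blocks=4, time_window=27).eval()
p2ds = torch.randn(4, 27, 17 * 2)   # (B, T, J * 2) absolute 2D inputs
p3d = torch.randn(4, 17 * 3)        # flattened 3D pose per window
with torch.no_grad():
    print(traj(p2ds, p3d).shape)    # torch.Size([4, 3]): per-window root trajectory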
@@ -189,6 +189,16 @@ TASK_OUTPUTS = {
    Tasks.body_2d_keypoints:
    [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES],

    # 3D human body keypoints detection result for single sample
    # {
    #     "poses": [
    #         [[x, y, z]*17],
    #         [[x, y, z]*17],
    #         [[x, y, z]*17]
    #     ]
    # }
    Tasks.body_3d_keypoints: [OutputKeys.POSES],

    # video single object tracking result for single video
    # {
    #     "boxes": [
@@ -89,6 +89,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                                   'damo/cv_diffusion_text-to-image-synthesis_tiny'),
    Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
                              'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
    Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
                              'damo/cv_canonical_body-3d-keypoints_video'),
    Tasks.face_detection: (Pipelines.face_detection,
                           'damo/cv_resnet_facedetection_scrfd10gkps'),
    Tasks.face_recognition: (Pipelines.face_recognition,
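Note: with this default registered, the task name alone resolves to the damo checkpoint, which is what test_run_modelhub_default_model in the test file below relies on:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

p = pipeline(Tasks.body_3d_keypoints)
# equivalent to pipeline(Tasks.body_3d_keypoints,
#                        model='damo/cv_canonical_body-3d-keypoints_video')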
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recognition_pipeline import AnimalRecognitionPipeline
    from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
    from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
    from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
    from .crowd_counting_pipeline import CrowdCountingPipeline
    from .image_detection_pipeline import ImageDetectionPipeline
@@ -46,6 +47,7 @@ else:
        'action_recognition_pipeline': ['ActionRecognitionPipeline'],
        'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
        'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
        'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
        'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
        'crowd_counting_pipeline': ['CrowdCountingPipeline'],
        'image_detection_pipeline': ['ImageDetectionPipeline'],
@@ -0,0 +1,213 @@
import os
import os.path as osp
from typing import Any, Dict, List, Union

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
    BodyKeypointsDetection3D, KeypointsTypes)
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


def convert_2_h36m(joints, joints_nbr=15):
    # re-order the 15 detected joints into the H36M joint order
    lst_mappings = [[0, 8], [1, 7], [2, 12], [3, 13], [4, 14], [5, 9], [6, 10],
                    [7, 11], [8, 1], [9, 2], [10, 3], [11, 4], [12, 5],
                    [13, 6], [14, 0]]
    nbr, dim = joints.shape
    h36m_joints = np.zeros((nbr, dim))
    for mapping in lst_mappings:
        h36m_joints[mapping[1]] = joints[mapping[0]]

    if joints_nbr == 17:
        lst_mappings_17 = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
                                    [5, 5], [6, 6], [7, 8], [8, 10], [9, 11],
                                    [10, 12], [11, 13], [12, 14], [13, 15],
                                    [14, 16]])
        h36m_joints_17 = np.zeros((17, 2))
        h36m_joints_17[lst_mappings_17[:, 1]] = h36m_joints[lst_mappings_17[:, 0]]
        # interpolate the two joints the 15-point skeleton lacks as midpoints
        h36m_joints_17[7] = (h36m_joints_17[0] + h36m_joints_17[8]) * 0.5
        h36m_joints_17[9] = (h36m_joints_17[8] + h36m_joints_17[10]) * 0.5
        h36m_joints = h36m_joints_17

    return h36m_joints
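Note: a usage sketch for the conversion above (random joints for illustration):

import numpy as np

joints_15 = np.random.rand(15, 2)                  # 2D detector output
h36m_17 = convert_2_h36m(joints_15, joints_nbr=17)
print(h36m_17.shape)                               # (17, 2); indices 7 and 9 are midpoints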
def smooth_pts(cur_pts, pre_pts, bbox, smooth_x=15.0, smooth_y=15.0):
    if pre_pts is None:
        return cur_pts

    w, h = bbox[1] - bbox[0]
    if w == 0 or h == 0:
        return cur_pts

    size_pre = len(pre_pts)
    size_cur = len(cur_pts)
    if (size_pre == 0 or size_cur == 0):
        return cur_pts

    factor_x = -(smooth_x / w)
    factor_y = -(smooth_y / h)
    for i in range(size_cur):
        # exponential weight: small displacements lean on the previous frame
        w_x = np.exp(factor_x * np.abs(cur_pts[i][0] - pre_pts[i][0]))
        w_y = np.exp(factor_y * np.abs(cur_pts[i][1] - pre_pts[i][1]))
        cur_pts[i][0] = (1.0 - w_x) * cur_pts[i][0] + w_x * pre_pts[i][0]
        cur_pts[i][1] = (1.0 - w_y) * cur_pts[i][1] + w_y * pre_pts[i][1]
    return cur_pts
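Note: the weight decays exponentially with the per-joint displacement normalized by box size, so jitter is damped while large, genuine motion passes through. Some illustrative numbers (bbox width 300 px assumed):

import numpy as np

w = 300.0
for dx in (2.0, 60.0):
    print(dx, np.exp(-(15.0 / w) * dx))
# 2 px  -> weight ~= 0.90: output stays close to the previous frame
# 60 px -> weight ~= 0.05: output follows the current detection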
def smoothing(lst_kps, lst_bboxes, smooth_x=15.0, smooth_y=15.0):
    assert lst_kps.shape[0] == lst_bboxes.shape[0]
    lst_smoothed_kps = []
    prev_pts = None
    for i in range(lst_kps.shape[0]):
        smoothed_cur_kps = smooth_pts(lst_kps[i], prev_pts,
                                      lst_bboxes[i][0:-1].reshape(2, 2),
                                      smooth_x, smooth_y)
        lst_smoothed_kps.append(smoothed_cur_kps)
        prev_pts = smoothed_cur_kps
    return np.array(lst_smoothed_kps)


def convert_2_h36m_data(lst_kps, lst_bboxes, joints_nbr=15):
    lst_kps = lst_kps.squeeze()
    lst_bboxes = lst_bboxes.squeeze()
    assert lst_kps.shape[0] == lst_bboxes.shape[0]

    lst_kps = smoothing(lst_kps, lst_bboxes)
    keypoints = []
    for i in range(lst_kps.shape[0]):
        h36m_joints_2d = convert_2_h36m(lst_kps[i], joints_nbr=joints_nbr)
        keypoints.append(h36m_joints_2d)
    return keypoints


@PIPELINES.register_module(
    Tasks.body_3d_keypoints, module_name=Pipelines.body_3d_keypoints)
class Body3DKeypointsPipeline(Pipeline):

    def __init__(self, model: Union[str, BodyKeypointsDetection3D], **kwargs):
        """Human body 3D pose estimation.

        Args:
            model (Union[str, BodyKeypointsDetection3D]): model id on the ModelScope hub,
                or an already constructed model instance.
        """
        super().__init__(model=model, **kwargs)

        self.keypoint_model_3d = model if isinstance(
            model, BodyKeypointsDetection3D) else Model.from_pretrained(model)
        self.keypoint_model_3d.eval()

        # init the human body 2D keypoints detection pipeline
        self.human_body_2d_kps_det_pipeline = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
        self.human_body_2d_kps_detector = pipeline(
            Tasks.body_2d_keypoints,
            model=self.human_body_2d_kps_det_pipeline,
            device='gpu' if torch.cuda.is_available() else 'cpu')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        video_frames = self.read_video_frames(input)
        if 0 == len(video_frames):
            res = {'success': False, 'msg': 'get video frame failed.'}
            return res

        all_2d_poses = []
        all_boxes_with_score = []
        # max number of video frames for which 3D joints are predicted
        max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
        for i, frame in enumerate(video_frames):
            kps_2d = self.human_body_2d_kps_detector(frame)
            # boxes: [[[x1, y1], [x2, y2]], ...], N human boxes per frame;
            # [0] selects the first detected bbox
            box = kps_2d['boxes'][0]
            pose = kps_2d['poses'][0]  # keypoints: [15, 2]
            score = kps_2d['scores'][0]  # keypoint scores
            all_2d_poses.append(pose)
            all_boxes_with_score.append(
                list(np.array(box).reshape((-1))) + [score])  # list of length 5
            if (i + 1) >= max_frame:
                break

        # 15: number of 2D keypoints, 2: keypoint coordinate (x, y)
        all_2d_poses_np = np.array(all_2d_poses).reshape(
            (len(all_2d_poses), 15, 2))
        all_boxes_np = np.array(all_boxes_with_score).reshape(
            (len(all_boxes_with_score), 5))  # [x1, y1, x2, y2, score]

        kps_2d_h36m_17 = convert_2_h36m_data(
            all_2d_poses_np,
            all_boxes_np,
            joints_nbr=self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS)
        kps_2d_h36m_17 = np.array(kps_2d_h36m_17)

        res = {'success': True, 'input_2d_pts': kps_2d_h36m_17}
        return res

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        if not input['success']:
            res = {'success': False, 'msg': 'preprocess failed.'}
            return res

        input_2d_pts = input['input_2d_pts']
        outputs = self.keypoint_model_3d.preprocess(input_2d_pts)
        outputs = self.keypoint_model_3d.forward(outputs)
        res = dict({'success': True}, **outputs)
        return res

    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        res = {OutputKeys.POSES: []}
        if input['success']:
            poses = input[KeypointsTypes.POSES_CAMERA]
            res = {OutputKeys.POSES: poses.data.cpu().numpy()}
        return res
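Note: per the model's forward docstring, POSES_CAMERA has shape [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM], so a caller sees, e.g. (names here are hypothetical):

import numpy as np

result = body_3d_pipeline('some_video.mp4')   # hypothetical pipeline instance
poses = np.array(result[OutputKeys.POSES])    # e.g. (1, NUM_FRAME, 17, 3) for a 17-joint config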
    def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
        """Read frames from a local video file or an opened video stream.

        Args:
            video_url (str or cv2.VideoCapture): video path or video stream.

        Raises:
            Exception: the video cannot be opened.

        Returns:
            List[np.ndarray]: list of video frames.
        """
        frames = []
        if isinstance(video_url, str):
            cap = cv2.VideoCapture(video_url)
            if not cap.isOpened():
                raise Exception(
                    'modelscope error: %s cannot be decoded by OpenCV.' %
                    (video_url))
        else:
            cap = video_url

        max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1
            frames.append(frame)
            if frame_idx >= max_frame_num:
                break
        cap.release()
        return frames
@@ -24,6 +24,7 @@ class CVTasks(object):
    human_object_interaction = 'human-object-interaction'
    face_image_generation = 'face-image-generation'
    body_2d_keypoints = 'body-2d-keypoints'
    body_3d_keypoints = 'body-3d-keypoints'
    general_recognition = 'general-recognition'
    image_classification = 'image-classification'
@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class Body3DKeypointsTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_canonical_body-3d-keypoints_video'
        self.test_video = 'data/test/videos/Walking.54138969.mp4'

    def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
        output = pipeline(pipeline_input)
        poses = np.array(output[OutputKeys.POSES])
        print(f'result 3d points shape {poses.shape}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_with_video_file(self):
        body_3d_keypoints = pipeline(
            Tasks.body_3d_keypoints, model=self.model_id)
        self.pipeline_inference(body_3d_keypoints, self.test_video)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_with_video_stream(self):
        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
        cap = cv2.VideoCapture(self.test_video)
        if not cap.isOpened():
            raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                            % (self.test_video))
        self.pipeline_inference(body_3d_keypoints, cap)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
        self.pipeline_inference(body_3d_keypoints, self.test_video)


if __name__ == '__main__':
    unittest.main()