hanyuan.chy (committed by yingda.chen) · 3 years ago
commit 00b448a2fe
12 changed files with 792 additions and 3 deletions
  1. data/test/videos/Walking.54138969.mp4 (+3 -0)
  2. modelscope/metainfo.py (+2 -0)
  3. modelscope/models/cv/__init__.py (+8 -3)
  4. modelscope/models/cv/body_3d_keypoints/__init__.py (+23 -0)
  5. modelscope/models/cv/body_3d_keypoints/body_3d_pose.py (+246 -0)
  6. modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py (+233 -0)
  7. modelscope/outputs.py (+10 -0)
  8. modelscope/pipelines/builder.py (+2 -0)
  9. modelscope/pipelines/cv/__init__.py (+2 -0)
 10. modelscope/pipelines/cv/body_3d_keypoints_pipeline.py (+213 -0)
 11. modelscope/utils/constant.py (+1 -0)
 12. tests/pipelines/test_body_3d_keypoints.py (+49 -0)

data/test/videos/Walking.54138969.mp4 (+3 -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b8f50a0537bfe7e082c5ad91b2b7ece61a0adbeb7489988e553909276bf920c
size 44217644

modelscope/metainfo.py (+2 -0)

@@ -20,6 +20,7 @@ class Models(object):
     gpen = 'gpen'
     product_retrieval_embedding = 'product-retrieval-embedding'
     body_2d_keypoints = 'body-2d-keypoints'
+    body_3d_keypoints = 'body-3d-keypoints'
     crowd_counting = 'HRNetCrowdCounting'
     panoptic_segmentation = 'swinL-panoptic-segmentation'
     image_reid_person = 'passvitb'
@@ -95,6 +96,7 @@ class Pipelines(object):
     general_recognition = 'resnet101-general-recognition'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
     body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
+    body_3d_keypoints = 'canonical_body-3d-keypoints_video'
     human_detection = 'resnet18-human-detection'
     object_detection = 'vit-object-detection'
     easycv_detection = 'easycv-detection'
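These two constants are the registry keys that tie the pieces of this commit together: the model class registers under Models.body_3d_keypoints and the pipeline under Pipelines.body_3d_keypoints. A hedged sketch of requesting the pipeline by these keys (the pipeline_name argument is an assumption about the public factory API, not shown in this diff):

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Keys added in this commit:
#   Models.body_3d_keypoints    -> 'body-3d-keypoints'
#   Pipelines.body_3d_keypoints -> 'canonical_body-3d-keypoints_video'
p = pipeline(
    task=Tasks.body_3d_keypoints,
    pipeline_name=Pipelines.body_3d_keypoints)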


modelscope/models/cv/__init__.py (+8 -3)

@@ -1,11 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 # yapf: disable
 from . import (action_recognition, animal_recognition, body_2d_keypoints,
-               cartoon, cmdssl_video_embedding, crowd_counting, face_detection,
-               face_generation, image_classification, image_color_enhance,
-               image_colorization, image_denoise, image_instance_segmentation,
+               body_3d_keypoints, cartoon, cmdssl_video_embedding,
+               crowd_counting, face_detection, face_generation,
+               image_classification, image_color_enhance, image_colorization,
+               image_denoise, image_instance_segmentation,
                image_panoptic_segmentation, image_portrait_enhancement,
                image_reid_person, image_semantic_segmentation,
                image_to_image_generation, image_to_image_translation,
                object_detection, product_retrieval_embedding,
                realtime_object_detection, salient_detection, super_resolution,
                video_single_object_tracking, video_summarization, virual_tryon)

 # yapf: enable

modelscope/models/cv/body_3d_keypoints/__init__.py (+23 -0)

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .body_3d_pose import BodyKeypointsDetection3D

else:
    _import_structure = {
        'body_3d_pose': ['BodyKeypointsDetection3D'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
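This LazyImportModule pattern defers loading the torch-heavy body_3d_pose submodule until it is first needed. A minimal hedged sketch of the behavior it provides (the print is illustrative, not part of this commit):

from modelscope.models.cv import body_3d_keypoints

# Importing the package above does not yet execute body_3d_pose.py;
# the first attribute access below triggers the real import.
model_cls = body_3d_keypoints.BodyKeypointsDetection3D
print(model_cls.__name__)  # 'BodyKeypointsDetection3D'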

modelscope/models/cv/body_3d_keypoints/body_3d_pose.py (+246 -0)

@@ -0,0 +1,246 @@
import logging
import os.path as osp
from typing import Any, Dict, List, Union

import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.body_3d_keypoints.canonical_pose_modules import (
    TemporalModel, TransCan3Dkeys)
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['BodyKeypointsDetection3D']


class KeypointsTypes(object):
    POSES_CAMERA = 'poses_camera'
    POSES_TRAJ = 'poses_traj'


@MODELS.register_module(
    Tasks.body_3d_keypoints, module_name=Models.body_3d_keypoints)
class BodyKeypointsDetection3D(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):

        super().__init__(model_dir, *args, **kwargs)

        self.model_dir = model_dir
        model_path = osp.join(self.model_dir, ModelFile.TORCH_MODEL_FILE)
        cfg_path = osp.join(self.model_dir, ModelFile.CONFIGURATION)
        self.cfg = Config.from_file(cfg_path)
        self._create_model()

        if not osp.exists(model_path):
            raise IOError(f'{model_path} does not exist.')

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.pretrained_state_dict = torch.load(
            model_path, map_location=self._device)

        self.load_pretrained()
        self.to_device(self._device)
        self.eval()

    def _create_model(self):
        self.model_pos = TemporalModel(
            self.cfg.model.MODEL.IN_NUM_JOINTS,
            self.cfg.model.MODEL.IN_2D_FEATURE,
            self.cfg.model.MODEL.OUT_NUM_JOINTS,
            filter_widths=self.cfg.model.MODEL.FILTER_WIDTHS,
            causal=self.cfg.model.MODEL.CAUSAL,
            dropout=self.cfg.model.MODEL.DROPOUT,
            channels=self.cfg.model.MODEL.CHANNELS,
            dense=self.cfg.model.MODEL.DENSE)

        receptive_field = self.model_pos.receptive_field()
        self.pad = (receptive_field - 1) // 2
        if self.cfg.model.MODEL.CAUSAL:
            self.causal_shift = self.pad
        else:
            self.causal_shift = 0

        self.model_traj = TransCan3Dkeys(
            in_channels=self.cfg.model.MODEL.IN_NUM_JOINTS
            * self.cfg.model.MODEL.IN_2D_FEATURE,
            num_features=1024,
            out_channels=self.cfg.model.MODEL.OUT_3D_FEATURE,
            num_blocks=4,
            time_window=receptive_field)

    def eval(self):
        self.model_pos.eval()
        self.model_traj.eval()

    def train(self):
        self.model_pos.train()
        self.model_traj.train()

    def to_device(self, device):
        self.model_pos = self.model_pos.to(device)
        self.model_traj = self.model_traj.to(device)

    def load_pretrained(self):
        if 'model_pos' in self.pretrained_state_dict:
            self.model_pos.load_state_dict(
                self.pretrained_state_dict['model_pos'], strict=False)
        else:
            logging.error(
                'Skip loading model_pos: key not found in pretrained_state_dict'
            )

        if 'model_traj' in self.pretrained_state_dict:
            self.model_traj.load_state_dict(
                self.pretrained_state_dict['model_traj'], strict=False)
        else:
            logging.error(
                'Skip loading model_traj: key not found in pretrained_state_dict'
            )
        logging.info('Load pretrained model done.')

    def preprocess(self, input: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
        """Preprocess the 2D input joints.

        Args:
            input (np.ndarray or torch.Tensor): [NUM_FRAME, NUM_JOINTS, 2], input 2d human body keypoints.

        Returns:
            Dict[str, Any]: canonical 2d points and root-relative joints.
        """
        if isinstance(input, torch.Tensor):
            # move to CPU (if necessary) and convert to numpy
            input = input.detach().cpu().numpy()
        pose2d = input

        pose2d_canonical = self.canonicalize_2Ds(
            pose2d, self.cfg.model.INPUT.FOCAL_LENGTH,
            self.cfg.model.INPUT.CENTER)
        pose2d_normalized = self.normalize_screen_coordinates(
            pose2d, self.cfg.model.INPUT.RES_W, self.cfg.model.INPUT.RES_H)
        pose2d_rr = pose2d_normalized
        pose2d_rr[:, 1:] -= pose2d_rr[:, :1]  # make joints root-relative

        # expand [NUM_FRAME, NUM_JOINTS, 2] to [1, NUM_FRAME, NUM_JOINTS, 2]
        pose2d_rr = np.expand_dims(
            np.pad(
                pose2d_rr,
                ((self.pad + self.causal_shift, self.pad - self.causal_shift),
                 (0, 0), (0, 0)), 'edge'),
            axis=0)
        pose2d_canonical = np.expand_dims(
            np.pad(
                pose2d_canonical,
                ((self.pad + self.causal_shift, self.pad - self.causal_shift),
                 (0, 0), (0, 0)), 'edge'),
            axis=0)
        pose2d_rr = torch.from_numpy(pose2d_rr.astype(np.float32))
        pose2d_canonical = torch.from_numpy(
            pose2d_canonical.astype(np.float32))

        inputs_2d = pose2d_rr.clone()
        if torch.cuda.is_available():
            inputs_2d = inputs_2d.cuda(non_blocking=True)

        # Positional model
        if self.cfg.model.MODEL.USE_2D_OFFSETS:
            inputs_2d[:, :, 0] = 0
        else:
            inputs_2d[:, :, 1:] += inputs_2d[:, :, :1]

        return {
            'inputs_2d': inputs_2d,
            'pose2d_rr': pose2d_rr,
            'pose2d_canonical': pose2d_canonical
        }

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """3D human pose estimation.

        Args:
            input (Dict):
                inputs_2d: [1, NUM_FRAME, NUM_JOINTS, 2]
                pose2d_rr: [1, NUM_FRAME, NUM_JOINTS, 2]
                pose2d_canonical: [1, NUM_FRAME, NUM_JOINTS, 2]
                NUM_FRAME = max(receptive_field + video_frame_number, video_frame_number)

        Returns:
            Dict[str, Any]:
                "camera_pose": Tensor, [1, NUM_FRAME, OUT_NUM_JOINTS, OUT_3D_FEATURE_DIM],
                    3D human pose keypoints in the camera frame.
                "camera_traj": Tensor, [1, NUM_FRAME, 1, 3],
                    root keypoint coordinates in the camera frame.
        """
        inputs_2d = input['inputs_2d']
        pose2d_rr = input['pose2d_rr']
        pose2d_canonical = input['pose2d_canonical']
        with torch.no_grad():
            # predict 3D pose keypoints
            predicted_3d_pos = self.model_pos(inputs_2d)

            # predict global trajectory
            b1, w1, n1, d1 = inputs_2d.shape

            input_pose2d_abs = self.get_abs_2d_pts(w1, pose2d_rr,
                                                   pose2d_canonical)
            b1, w1, n1, d1 = input_pose2d_abs.size()
            b2, w2, n2, d2 = predicted_3d_pos.size()

            if torch.cuda.is_available():
                input_pose2d_abs = input_pose2d_abs.cuda(non_blocking=True)

            predicted_3d_traj = self.model_traj(
                input_pose2d_abs.view(b1, w1, n1 * d1),
                predicted_3d_pos.view(b2 * w2, n2 * d2)).view(b2, w2, -1, 3)

        predict_dict = {
            KeypointsTypes.POSES_CAMERA: predicted_3d_pos,
            KeypointsTypes.POSES_TRAJ: predicted_3d_traj
        }

        return predict_dict

    def get_abs_2d_pts(self, input_video_frame_num, pose2d_rr,
                       pose2d_canonical):
        pad = self.pad
        w = input_video_frame_num - pad * 2

        lst_pose2d_rr = []
        lst_pose2d_canonical = []
        for i in range(pad, w + pad):
            lst_pose2d_rr.append(pose2d_rr[:, i - pad:i + pad + 1])
            lst_pose2d_canonical.append(pose2d_canonical[:,
                                                         i - pad:i + pad + 1])

        input_pose2d_rr = torch.cat(lst_pose2d_rr, dim=0)
        input_pose2d_canonical = torch.cat(lst_pose2d_canonical, dim=0)

        if self.cfg.model.MODEL.USE_CANONICAL_COORDS:
            input_pose2d_abs = input_pose2d_canonical.clone()
        else:
            input_pose2d_abs = input_pose2d_rr.clone()
            input_pose2d_abs[:, :, 1:] += input_pose2d_abs[:, :, :1]

        return input_pose2d_abs

    def canonicalize_2Ds(self, pos2d, f, c):
        cs = np.array([c[0], c[1]]).reshape(1, 1, 2)
        fs = np.array([f[0], f[1]]).reshape(1, 1, 2)
        canonical_2Ds = (pos2d - cs) / fs
        return canonical_2Ds

    def normalize_screen_coordinates(self, X, w, h):
        assert X.shape[-1] == 2

        # Normalize so that [0, w] is mapped to [-1, 1], while preserving the aspect ratio
        return X / w * 2 - [1, h / w]
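As a quick numeric check of the two coordinate transforms above (the image size, focal length, and principal point below are illustrative values, not the shipped model config):

import numpy as np

# one frame, two joints, pixel coordinates in a 1000x1002 image
X = np.array([[[500.0, 501.0], [0.0, 0.0]]])
w, h = 1000, 1002

# normalize_screen_coordinates: [0, w] -> [-1, 1], aspect ratio preserved
print(X / w * 2 - [1, h / w])  # [[[ 0.  0. ], [-1. -1.002]]]

# canonicalize_2Ds: shift by principal point c, divide by focal length f
f, c = (1145.0, 1145.0), (512.0, 515.0)
print((X - np.array(c).reshape(1, 1, 2)) / np.array(f).reshape(1, 1, 2))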

modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py (+233 -0)

@@ -0,0 +1,233 @@
# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D
import torch
import torch.nn as nn

class TemporalModelBase(nn.Module):
    """
    Do not instantiate this class.
    """

    def __init__(self, num_joints_in, in_features, num_joints_out,
                 filter_widths, causal, dropout, channels):
        super().__init__()

        # Validate input
        for fw in filter_widths:
            assert fw % 2 != 0, 'Only odd filter widths are supported'

        self.num_joints_in = num_joints_in
        self.in_features = in_features
        self.num_joints_out = num_joints_out
        self.filter_widths = filter_widths

        self.drop = nn.Dropout(dropout)
        self.relu = nn.ReLU(inplace=True)

        self.pad = [filter_widths[0] // 2]
        self.expand_bn = nn.BatchNorm1d(channels, momentum=0.1)
        self.shrink = nn.Conv1d(channels, num_joints_out * 3, 1)

    def set_bn_momentum(self, momentum):
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def receptive_field(self):
        """
        Return the total receptive field of this model as # of frames.
        """
        frames = 0
        for f in self.pad:
            frames += f
        return 1 + 2 * frames

    def total_causal_shift(self):
        """
        Return the asymmetric offset for sequence padding.
        The returned value is typically 0 if causal convolutions are disabled,
        otherwise it is half the receptive field.
        """
        frames = self.causal_shift[0]
        next_dilation = self.filter_widths[0]
        for i in range(1, len(self.filter_widths)):
            frames += self.causal_shift[i] * next_dilation
            next_dilation *= self.filter_widths[i]
        return frames

    def forward(self, x):
        assert len(x.shape) == 4
        assert x.shape[-2] == self.num_joints_in
        assert x.shape[-1] == self.in_features

        sz = x.shape[:3]
        x = x.view(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)

        x = self._forward_blocks(x)

        x = x.permute(0, 2, 1)
        x = x.view(sz[0], -1, self.num_joints_out, 3)

        return x


class TemporalModel(TemporalModelBase):
    """
    Reference 3D pose estimation model with temporal convolutions.
    This implementation can be used for all use-cases.
    """

    def __init__(self,
                 num_joints_in,
                 in_features,
                 num_joints_out,
                 filter_widths,
                 causal=False,
                 dropout=0.25,
                 channels=1024,
                 dense=False):
        """
        Initialize this model.

        Arguments:
        num_joints_in -- number of input joints (e.g. 17 for Human3.6M)
        in_features -- number of input features for each joint (typically 2 for 2D input)
        num_joints_out -- number of output joints (can be different than input)
        filter_widths -- list of convolution widths, which also determines the # of blocks and receptive field
        causal -- use causal convolutions instead of symmetric convolutions (for real-time applications)
        dropout -- dropout probability
        channels -- number of convolution channels
        dense -- use regular dense convolutions instead of dilated convolutions (ablation experiment)
        """
        super().__init__(num_joints_in, in_features, num_joints_out,
                         filter_widths, causal, dropout, channels)

        self.expand_conv = nn.Conv1d(
            num_joints_in * in_features,
            channels,
            filter_widths[0],
            bias=False)

        layers_conv = []
        layers_bn = []

        self.causal_shift = [(filter_widths[0]) // 2 if causal else 0]
        next_dilation = filter_widths[0]
        for i in range(1, len(filter_widths)):
            self.pad.append((filter_widths[i] - 1) * next_dilation // 2)
            self.causal_shift.append((filter_widths[i] // 2
                                      * next_dilation) if causal else 0)

            layers_conv.append(
                nn.Conv1d(
                    channels,
                    channels,
                    filter_widths[i] if not dense else (2 * self.pad[-1] + 1),
                    dilation=next_dilation if not dense else 1,
                    bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))
            layers_conv.append(
                nn.Conv1d(channels, channels, 1, dilation=1, bias=False))
            layers_bn.append(nn.BatchNorm1d(channels, momentum=0.1))

            next_dilation *= filter_widths[i]

        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def _forward_blocks(self, x):
        x = self.drop(self.relu(self.expand_bn(self.expand_conv(x))))
        for i in range(len(self.pad) - 1):
            pad = self.pad[i + 1]
            shift = self.causal_shift[i + 1]
            res = x[:, :, pad + shift:x.shape[2] - pad + shift]
            x = self.drop(
                self.relu(self.layers_bn[2 * i](self.layers_conv[2 * i](x))))
            x = res + self.drop(
                self.relu(self.layers_bn[2 * i + 1](
                    self.layers_conv[2 * i + 1](x))))

        x = self.shrink(x)
        return x


# regression of the trajectory
class TransCan3Dkeys(nn.Module):

    def __init__(self,
                 in_channels=74,
                 num_features=256,
                 out_channels=44,
                 time_window=10,
                 num_blocks=2):
        super().__init__()
        self.in_channels = in_channels
        self.num_features = num_features
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.time_window = time_window

        self.expand_bn = nn.BatchNorm1d(self.num_features, momentum=0.1)
        self.conv1 = nn.Sequential(
            nn.ReplicationPad1d(1),
            nn.Conv1d(
                self.in_channels, self.num_features, kernel_size=3,
                bias=False), self.expand_bn, nn.ReLU(inplace=True),
            nn.Dropout(p=0.25))
        self._make_blocks()
        self.pad = nn.ReplicationPad1d(4)
        self.relu = nn.ReLU(inplace=True)
        self.drop = nn.Dropout(p=0.25)
        self.reduce = nn.Conv1d(
            self.num_features, self.num_features, kernel_size=self.time_window)
        self.embedding_3d_1 = nn.Linear(in_channels // 2 * 3, 500)
        self.embedding_3d_2 = nn.Linear(500, 500)
        self.LReLU1 = nn.LeakyReLU()
        self.LReLU2 = nn.LeakyReLU()
        self.LReLU3 = nn.LeakyReLU()
        self.out1 = nn.Linear(self.num_features + 500, self.num_features)
        self.out2 = nn.Linear(self.num_features, self.out_channels)

    def _make_blocks(self):
        layers_conv = []
        layers_bn = []
        for i in range(self.num_blocks):
            layers_conv.append(
                nn.Conv1d(
                    self.num_features,
                    self.num_features,
                    kernel_size=5,
                    bias=False,
                    dilation=2))
            layers_bn.append(nn.BatchNorm1d(self.num_features))
        self.layers_conv = nn.ModuleList(layers_conv)
        self.layers_bn = nn.ModuleList(layers_bn)

    def set_bn_momentum(self, momentum):
        self.expand_bn.momentum = momentum
        for bn in self.layers_bn:
            bn.momentum = momentum

    def forward(self, p2ds, p3d):
        """
        Args:
            p2ds: 2D keypoints, shape (B x T x C)
            p3d: flattened 3D pose predictions, shape (B x J*3)
        """
        B, T, C = p2ds.shape
        x = p2ds.permute((0, 2, 1))
        x = self.conv1(x)
        for i in range(self.num_blocks):
            pre = x
            x = self.pad(x)
            x = self.layers_conv[i](x)
            x = self.layers_bn[i](x)
            x = self.drop(self.relu(x))
            x = pre + x
        x_2d = self.relu(self.reduce(x))
        x_2d = x_2d.view(B, -1)
        x_3d = self.LReLU1(self.embedding_3d_1(p3d))
        x = torch.cat((x_2d, x_3d), 1)
        x = self.LReLU3(self.out1(x))
        x = self.out2(x)
        return x
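For intuition, the receptive field of TemporalModel grows multiplicatively with filter_widths. A small sketch that mirrors the padding arithmetic in TemporalModelBase and TemporalModel (the example widths are assumptions for illustration, not necessarily the shipped config):

def receptive_field(filter_widths):
    # pad[0] = fw[0] // 2; each later level adds
    # (fw[i] - 1) * prod(fw[:i]) // 2 frames on each side
    pad = [filter_widths[0] // 2]
    next_dilation = filter_widths[0]
    for fw in filter_widths[1:]:
        pad.append((fw - 1) * next_dilation // 2)
        next_dilation *= fw
    return 1 + 2 * sum(pad)

print(receptive_field([3, 3, 3]))        # 27 frames
print(receptive_field([3, 3, 3, 3, 3]))  # 243 frames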

modelscope/outputs.py (+10 -0)

@@ -189,6 +189,16 @@ TASK_OUTPUTS = {
     Tasks.body_2d_keypoints:
     [OutputKeys.POSES, OutputKeys.SCORES, OutputKeys.BOXES],

+    # 3D human body keypoints detection result for single sample
+    # {
+    #     "poses": [
+    #         [[x, y, z] * 17],
+    #         [[x, y, z] * 17],
+    #         [[x, y, z] * 17]
+    #     ]
+    # }
+    Tasks.body_3d_keypoints: [OutputKeys.POSES],
+
     # video single object tracking result for single video
     # {
     #     "boxes": [


modelscope/pipelines/builder.py (+2 -0)

@@ -89,6 +89,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
                               'damo/cv_diffusion_text-to-image-synthesis_tiny'),
     Tasks.body_2d_keypoints: (Pipelines.body_2d_keypoints,
                               'damo/cv_hrnetv2w32_body-2d-keypoints_image'),
+    Tasks.body_3d_keypoints: (Pipelines.body_3d_keypoints,
+                              'damo/cv_canonical_body-3d-keypoints_video'),
     Tasks.face_detection: (Pipelines.face_detection,
                            'damo/cv_resnet_facedetection_scrfd10gkps'),
     Tasks.face_recognition: (Pipelines.face_recognition,


modelscope/pipelines/cv/__init__.py (+2 -0)

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recognition_pipeline import AnimalRecognitionPipeline
     from .body_2d_keypoints_pipeline import Body2DKeypointsPipeline
+    from .body_3d_keypoints_pipeline import Body3DKeypointsPipeline
     from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
     from .crowd_counting_pipeline import CrowdCountingPipeline
     from .image_detection_pipeline import ImageDetectionPipeline
@@ -46,6 +47,7 @@ else:
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
         'animal_recognition_pipeline': ['AnimalRecognitionPipeline'],
         'body_2d_keypoints_pipeline': ['Body2DKeypointsPipeline'],
+        'body_3d_keypoints_pipeline': ['Body3DKeypointsPipeline'],
         'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
         'crowd_counting_pipeline': ['CrowdCountingPipeline'],
         'image_detection_pipeline': ['ImageDetectionPipeline'],


modelscope/pipelines/cv/body_3d_keypoints_pipeline.py (+213 -0)

@@ -0,0 +1,213 @@
import os
import os.path as osp
from typing import Any, Dict, List, Union

import cv2
import numpy as np
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
    BodyKeypointsDetection3D, KeypointsTypes)
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Input, Model, Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


def convert_2_h36m(joints, joints_nbr=15):
    lst_mappings = [[0, 8], [1, 7], [2, 12], [3, 13], [4, 14], [5, 9], [6, 10],
                    [7, 11], [8, 1], [9, 2], [10, 3], [11, 4], [12, 5],
                    [13, 6], [14, 0]]
    nbr, dim = joints.shape
    h36m_joints = np.zeros((nbr, dim))
    for mapping in lst_mappings:
        h36m_joints[mapping[1]] = joints[mapping[0]]

    if joints_nbr == 17:
        lst_mappings_17 = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4],
                                    [5, 5], [6, 6], [7, 8], [8, 10], [9, 11],
                                    [10, 12], [11, 13], [12, 14], [13, 15],
                                    [14, 16]])
        h36m_joints_17 = np.zeros((17, 2))
        h36m_joints_17[lst_mappings_17[:, 1]] = h36m_joints[lst_mappings_17[:, 0]]
        h36m_joints_17[7] = (h36m_joints_17[0] + h36m_joints_17[8]) * 0.5
        h36m_joints_17[9] = (h36m_joints_17[8] + h36m_joints_17[10]) * 0.5
        h36m_joints = h36m_joints_17

    return h36m_joints


def smooth_pts(cur_pts, pre_pts, bbox, smooth_x=15.0, smooth_y=15.0):
    if pre_pts is None:
        return cur_pts

    w, h = bbox[1] - bbox[0]
    if w == 0 or h == 0:
        return cur_pts

    size_pre = len(pre_pts)
    size_cur = len(cur_pts)
    if (size_pre == 0 or size_cur == 0):
        return cur_pts

    factor_x = -(smooth_x / w)
    factor_y = -(smooth_y / h)  # scale the y gate by box height

    for i in range(size_cur):
        w_x = np.exp(factor_x * np.abs(cur_pts[i][0] - pre_pts[i][0]))
        w_y = np.exp(factor_y * np.abs(cur_pts[i][1] - pre_pts[i][1]))
        cur_pts[i][0] = (1.0 - w_x) * cur_pts[i][0] + w_x * pre_pts[i][0]
        cur_pts[i][1] = (1.0 - w_y) * cur_pts[i][1] + w_y * pre_pts[i][1]
    return cur_pts


def smoothing(lst_kps, lst_bboxes, smooth_x=15.0, smooth_y=15.0):
    assert lst_kps.shape[0] == lst_bboxes.shape[0]

    lst_smoothed_kps = []
    prev_pts = None
    for i in range(lst_kps.shape[0]):
        smoothed_cur_kps = smooth_pts(lst_kps[i], prev_pts,
                                      lst_bboxes[i][0:-1].reshape(2, 2),
                                      smooth_x, smooth_y)
        lst_smoothed_kps.append(smoothed_cur_kps)
        prev_pts = smoothed_cur_kps

    return np.array(lst_smoothed_kps)
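The smoothing above is an exponential gate: the larger a joint's frame-to-frame motion relative to the box size, the less the previous frame is trusted. A quick numeric check (box width and displacements are illustrative):

import numpy as np

w_box, smooth_x = 200.0, 15.0   # box width, smoothing strength
for dx in (1.0, 5.0, 20.0):      # per-joint x displacement in pixels
    gate = np.exp(-(smooth_x / w_box) * dx)
    # blended x = (1 - gate) * current + gate * previous
    print(f'dx={dx:5.1f}  weight on previous frame: {gate:.3f}')
# dx=1.0 -> 0.928 (heavy smoothing); dx=20.0 -> 0.223 (mostly current)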


def convert_2_h36m_data(lst_kps, lst_bboxes, joints_nbr=15):
    lst_kps = lst_kps.squeeze()
    lst_bboxes = lst_bboxes.squeeze()

    assert lst_kps.shape[0] == lst_bboxes.shape[0]

    lst_kps = smoothing(lst_kps, lst_bboxes)

    keypoints = []
    for i in range(lst_kps.shape[0]):
        h36m_joints_2d = convert_2_h36m(lst_kps[i], joints_nbr=joints_nbr)
        keypoints.append(h36m_joints_2d)
    return keypoints


@PIPELINES.register_module(
    Tasks.body_3d_keypoints, module_name=Pipelines.body_3d_keypoints)
class Body3DKeypointsPipeline(Pipeline):

    def __init__(self, model: Union[str, BodyKeypointsDetection3D], **kwargs):
        """Human body 3D pose estimation.

        Args:
            model (Union[str, BodyKeypointsDetection3D]): model id on the
                ModelScope hub, or a loaded BodyKeypointsDetection3D instance.
        """
        super().__init__(model=model, **kwargs)

        self.keypoint_model_3d = model if isinstance(
            model, BodyKeypointsDetection3D) else Model.from_pretrained(model)
        self.keypoint_model_3d.eval()

        # init human body 2D keypoints detection pipeline
        self.human_body_2d_kps_det_pipeline = 'damo/cv_hrnetv2w32_body-2d-keypoints_image'
        self.human_body_2d_kps_detector = pipeline(
            Tasks.body_2d_keypoints,
            model=self.human_body_2d_kps_det_pipeline,
            device='gpu' if torch.cuda.is_available() else 'cpu')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        video_frames = self.read_video_frames(input)
        if len(video_frames) == 0:
            res = {'success': False, 'msg': 'get video frame failed.'}
            return res

        all_2d_poses = []
        all_boxes_with_score = []
        # max number of video frames for which 3D joints are predicted
        max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
        for i, frame in enumerate(video_frames):
            kps_2d = self.human_body_2d_kps_detector(frame)
            # boxes: [[[x1, y1], [x2, y2]], ...], N detected boxes per frame;
            # [0] selects the first detected person
            box = kps_2d['boxes'][0]
            pose = kps_2d['poses'][0]  # keypoints: [15, 2]
            score = kps_2d['scores'][0]  # keypoint scores
            all_2d_poses.append(pose)
            # flatten the box to [x1, y1, x2, y2] and append the score
            all_boxes_with_score.append(
                list(np.array(box).reshape((-1))) + [score])
            if (i + 1) >= max_frame:
                break

        # 15: number of 2d keypoints, 2: keypoint coordinates (x, y)
        all_2d_poses_np = np.array(all_2d_poses).reshape(
            (len(all_2d_poses), 15, 2))
        # [x1, y1, x2, y2, score]
        all_boxes_np = np.array(all_boxes_with_score).reshape(
            (len(all_boxes_with_score), 5))

        kps_2d_h36m_17 = convert_2_h36m_data(
            all_2d_poses_np,
            all_boxes_np,
            joints_nbr=self.keypoint_model_3d.cfg.model.MODEL.IN_NUM_JOINTS)
        kps_2d_h36m_17 = np.array(kps_2d_h36m_17)
        res = {'success': True, 'input_2d_pts': kps_2d_h36m_17}
        return res

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        if not input['success']:
            res = {'success': False, 'msg': 'preprocess failed.'}
            return res

        input_2d_pts = input['input_2d_pts']
        outputs = self.keypoint_model_3d.preprocess(input_2d_pts)
        outputs = self.keypoint_model_3d.forward(outputs)
        res = dict({'success': True}, **outputs)
        return res

    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        res = {OutputKeys.POSES: []}

        if input['success']:
            poses = input[KeypointsTypes.POSES_CAMERA]
            res = {OutputKeys.POSES: poses.data.cpu().numpy()}
        return res

    def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
        """Read video frames from a local video file or an opened video stream.

        Args:
            video_url (str or cv2.VideoCapture): Video path or video stream.

        Raises:
            Exception: if the video cannot be opened.

        Returns:
            List[np.ndarray]: List of video frames.
        """
        frames = []
        if isinstance(video_url, str):
            cap = cv2.VideoCapture(video_url)
            if not cap.isOpened():
                raise Exception(
                    'modelscope error: %s cannot be decoded by OpenCV.' %
                    (video_url))
        else:
            cap = video_url

        max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1
            frames.append(frame)
            if frame_idx >= max_frame_num:
                break
        cap.release()
        return frames
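Putting the pieces together, a minimal end-to-end sketch (the model id and video path are taken from this commit; the rest is standard pipeline usage):

import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

body_3d = pipeline(
    Tasks.body_3d_keypoints,
    model='damo/cv_canonical_body-3d-keypoints_video')
result = body_3d('data/test/videos/Walking.54138969.mp4')

poses = np.array(result[OutputKeys.POSES])  # [num_frames, 17, 3]
print(poses.shape)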

modelscope/utils/constant.py (+1 -0)

@@ -24,6 +24,7 @@ class CVTasks(object):
     human_object_interaction = 'human-object-interaction'
     face_image_generation = 'face-image-generation'
     body_2d_keypoints = 'body-2d-keypoints'
+    body_3d_keypoints = 'body-3d-keypoints'
     general_recognition = 'general-recognition'

     image_classification = 'image-classification'


tests/pipelines/test_body_3d_keypoints.py (+49 -0)

@@ -0,0 +1,49 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import cv2
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class Body3DKeypointsTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_canonical_body-3d-keypoints_video'
        self.test_video = 'data/test/videos/Walking.54138969.mp4'

    def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
        output = pipeline(pipeline_input)
        poses = np.array(output[OutputKeys.POSES])
        print(f'result 3d points shape {poses.shape}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_with_video_file(self):
        body_3d_keypoints = pipeline(
            Tasks.body_3d_keypoints, model=self.model_id)
        self.pipeline_inference(body_3d_keypoints, self.test_video)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub_with_video_stream(self):
        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
        cap = cv2.VideoCapture(self.test_video)
        if not cap.isOpened():
            raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                            % (self.test_video))
        self.pipeline_inference(body_3d_keypoints, cap)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
        self.pipeline_inference(body_3d_keypoints, self.test_video)


if __name__ == '__main__':
    unittest.main()
