From 7f468acca37f91c2cc900c62caa65c653ea514d7 Mon Sep 17 00:00:00 2001
From: "hanyuan.chy"
Date: Sat, 1 Oct 2022 18:34:23 +0800
Subject: [PATCH] [to #42322933]style(license): add license + render result
 poses with video

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10263904
---
 .../cv/body_3d_keypoints/body_3d_pose.py      |   2 +
 .../canonical_pose_modules.py                 |   2 +-
 modelscope/outputs.py                         |  21 ++-
 .../cv/body_3d_keypoints_pipeline.py          | 157 +++++++++++++++++-
 tests/pipelines/test_body_3d_keypoints.py    |  19 ++-
 5 files changed, 183 insertions(+), 18 deletions(-)

diff --git a/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py
index 87cd4962..3e920d12 100644
--- a/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py
+++ b/modelscope/models/cv/body_3d_keypoints/body_3d_pose.py
@@ -1,3 +1,5 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
 import logging
 import os.path as osp
 from typing import Any, Dict, List, Union
diff --git a/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py
index b3eac2e5..b7f0c4a3 100644
--- a/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py
+++ b/modelscope/models/cv/body_3d_keypoints/canonical_pose_modules.py
@@ -1,4 +1,4 @@
-# The implementation is based on OSTrack, available at https://github.com/facebookresearch/VideoPose3D
+# The implementation is based on VideoPose3D, available at https://github.com/facebookresearch/VideoPose3D
 
 import torch
 import torch.nn as nn
diff --git a/modelscope/outputs.py b/modelscope/outputs.py
index 52f3c47e..f13bbed9 100644
--- a/modelscope/outputs.py
+++ b/modelscope/outputs.py
@@ -21,6 +21,7 @@ class OutputKeys(object):
     POLYGONS = 'polygons'
     OUTPUT = 'output'
     OUTPUT_IMG = 'output_img'
+    OUTPUT_VIDEO = 'output_video'
     OUTPUT_PCM = 'output_pcm'
     IMG_EMBEDDING = 'img_embedding'
     SPO_LIST = 'spo_list'
@@ -218,13 +219,21 @@ TASK_OUTPUTS = {
 
     # 3D human body keypoints detection result for single sample
     # {
-    #     "poses": [
-    #         [[x, y, z]*17],
-    #         [[x, y, z]*17],
-    #         [[x, y, z]*17]
-    #     ]
+    #     "poses": [  # 3D pose coordinates in the camera coordinate system
+    #         [[x, y, z]*17],  # 17 joints per frame
+    #         [[x, y, z]*17],
+    #         ...
+    #     ],
+    #     "timestamps": [  # timestamps of all frames
+    #         "00:00:00.230",
+    #         "00:00:00.560",
+    #         "00:00:00.690",
+    #     ],
+    #     "output_video": "path_to_rendered_video",  # optional; only
+    #     available when the "render" option is enabled
     # }
-    Tasks.body_3d_keypoints: [OutputKeys.POSES],
+    Tasks.body_3d_keypoints:
+    [OutputKeys.POSES, OutputKeys.TIMESTAMPS, OutputKeys.OUTPUT_VIDEO],
 
     # 2D hand keypoints result for single sample
     # {
diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
index e9e4e9e8..474c0e54 100644
--- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
+++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
@@ -1,10 +1,19 @@
-import os
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import datetime
 import os.path as osp
+import tempfile
 from typing import Any, Dict, List, Union
 
 import cv2
+import matplotlib
+import matplotlib.pyplot as plt
+import mpl_toolkits.mplot3d.axes3d as p3
 import numpy as np
 import torch
+from matplotlib import animation
+from matplotlib.animation import writers
+from matplotlib.ticker import MultipleLocator
 
 from modelscope.metainfo import Pipelines
 from modelscope.models.cv.body_3d_keypoints.body_3d_pose import (
@@ -16,6 +25,8 @@
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
 
+matplotlib.use('Agg')
+
 logger = get_logger()
 
@@ -121,7 +132,13 @@ class Body3DKeypointsPipeline(Pipeline):
             device='gpu' if torch.cuda.is_available() else 'cpu')
 
     def preprocess(self, input: Input) -> Dict[str, Any]:
-        video_frames = self.read_video_frames(input)
+        video_url = input.get('input_video')
+        self.output_video_path = input.get('output_video_path')
+        if self.output_video_path is None:
+            self.output_video_path = tempfile.NamedTemporaryFile(
+                suffix='.mp4').name
+
+        video_frames = self.read_video_frames(video_url)
         if 0 == len(video_frames):
             res = {'success': False, 'msg': 'get video frame failed.'}
             return res
@@ -168,13 +185,21 @@ class Body3DKeypointsPipeline(Pipeline):
         return res
 
     def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        res = {OutputKeys.POSES: []}
+        res = {OutputKeys.POSES: [], OutputKeys.TIMESTAMPS: []}
 
         if not input['success']:
             pass
         else:
             poses = input[KeypointsTypes.POSES_CAMERA]
-            res = {OutputKeys.POSES: poses.data.cpu().numpy()}
+            pred_3d_pose = poses.data.cpu().numpy()[
+                0]  # [frame_num, joint_num, joint_dim]
+
+            if 'render' in self.keypoint_model_3d.cfg.keys():
+                self.render_prediction(pred_3d_pose)
+                res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path
+
+            res[OutputKeys.POSES] = pred_3d_pose
+            res[OutputKeys.TIMESTAMPS] = self.timestamps
         return res
 
     def read_video_frames(self, video_url: Union[str, cv2.VideoCapture]):
@@ -189,7 +214,15 @@ class Body3DKeypointsPipeline(Pipeline):
         Returns:
             [nd.array]: List of video frames.
         """
+
+        def timestamp_format(seconds):
+            m, s = divmod(seconds, 60)
+            h, m = divmod(m, 60)
+            time = '%02d:%02d:%06.3f' % (h, m, s)
+            return time
+
         frames = []
+        self.timestamps = []  # for video render
        if isinstance(video_url, str):
             cap = cv2.VideoCapture(video_url)
             if not cap.isOpened():
@@ -199,15 +232,131 @@ class Body3DKeypointsPipeline(Pipeline):
         else:
             cap = video_url
 
+        self.fps = cap.get(cv2.CAP_PROP_FPS)
+        if self.fps is None or self.fps <= 0:
+            raise Exception('modelscope error: %s cannot get video fps info.' %
+                            (video_url))
+
         max_frame_num = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME
         frame_idx = 0
         while True:
             ret, frame = cap.read()
             if not ret:
                 break
+            self.timestamps.append(
+                timestamp_format(seconds=frame_idx / self.fps))
             frame_idx += 1
             frames.append(frame)
             if frame_idx >= max_frame_num:
                 break
         cap.release()
         return frames
+
+    def render_prediction(self, pose3d_cam_rr):
+        """Render the predicted 3D poses and save them as a video animation.
+
+        Args:
+            pose3d_cam_rr (np.ndarray): [frame_num, joint_num, joint_dim], predicted 3D pose joints in camera coordinates.
+
+        Returns:
+        """
+        frame_num = pose3d_cam_rr.shape[0]
+
+        left_points = [11, 12, 13, 4, 5, 6]  # indices of left-side joints
+        edges = [[0, 1], [0, 4], [0, 7], [1, 2], [4, 5], [5, 6],
+                 [2, 3], [7, 8], [8, 9], [8, 11], [8, 14], [14, 15],
+                 [15, 16], [11, 12], [12, 13],
+                 [9, 10]]  # connections between joints
+
+        fig = plt.figure()
+        ax = p3.Axes3D(fig)
+        x_major_locator = MultipleLocator(0.5)
+
+        ax.xaxis.set_major_locator(x_major_locator)
+        ax.yaxis.set_major_locator(x_major_locator)
+        ax.zaxis.set_major_locator(x_major_locator)
+        ax.set_xlabel('X')
+        ax.set_ylabel('Y')
+        ax.set_zlabel('Z')
+        ax.set_xlim(-1, 1)
+        ax.set_ylim(-1, 1)
+        ax.set_zlim(-1, 1)
+        # view direction
+        azim = self.keypoint_model_3d.cfg.render.azim
+        elev = self.keypoint_model_3d.cfg.render.elev
+        ax.view_init(elev, azim)
+
+        # initialize the plot with the first frame's joints
+        x = pose3d_cam_rr[0, :, 0]
+        y = pose3d_cam_rr[0, :, 1]
+        z = pose3d_cam_rr[0, :, 2]
+        points, = ax.plot(x, y, z, 'r.')
+
+        def renderBones(xs, ys, zs):
+            """Render the skeleton bones (connections between joints).
+
+            Args:
+                xs (np.ndarray): x coordinates of all joints
+                ys (np.ndarray): y coordinates of all joints
+                zs (np.ndarray): z coordinates of all joints
+            """
+            bones = {}
+            for idx, edge in enumerate(edges):
+                index1, index2 = edge[0], edge[1]
+                if index1 in left_points:
+                    edge_color = 'red'
+                else:
+                    edge_color = 'blue'
+                connect = ax.plot([xs[index1], xs[index2]],
+                                  [ys[index1], ys[index2]],
+                                  [zs[index1], zs[index2]],
+                                  linewidth=2,
+                                  color=edge_color)  # plot edge
+                bones[idx] = connect[0]
+            return bones
+
+        bones = renderBones(x, y, z)
+
+        def update(frame_idx, points, bones):
+            """Update the animation for one frame.
+
+            Args:
+                frame_idx (int): frame index
+                points (mpl_toolkits.mplot3d.art3d.Line3D): skeleton joint plotter
+                bones (dict[int, mpl_toolkits.mplot3d.art3d.Line3D]): bone connection plotters
+
+            Returns:
+                tuple: updated joint and bone plotters
+            """
+            xs = pose3d_cam_rr[frame_idx, :, 0]
+            ys = pose3d_cam_rr[frame_idx, :, 1]
+            zs = pose3d_cam_rr[frame_idx, :, 2]
+
+            # update bones
+            for idx, edge in enumerate(edges):
+                index1, index2 = edge[0], edge[1]
+                x1x2 = (xs[index1], xs[index2])
+                y1y2 = (ys[index1], ys[index2])
+                z1z2 = (zs[index1], zs[index2])
+                bones[idx].set_xdata(x1x2)
+                bones[idx].set_ydata(y1y2)
+                bones[idx].set_3d_properties(z1z2, 'z')
+
+            # update joints
+            points.set_data(xs, ys)
+            points.set_3d_properties(zs, 'z')
+            if 0 == frame_idx % 100:  # log progress every 100 frames
+                logger.info(f'rendering {frame_idx}/{frame_num}')
+            return points, bones
+
+        ani = animation.FuncAnimation(
+            fig=fig,
+            func=update,
+            frames=frame_num,
+            interval=self.fps,
+            fargs=(points, bones))
+
+        # save mp4
+        Writer = writers['ffmpeg']
+        writer = Writer(fps=self.fps, metadata={}, bitrate=4096)
+        ani.save(self.output_video_path, writer=writer)
diff --git a/tests/pipelines/test_body_3d_keypoints.py b/tests/pipelines/test_body_3d_keypoints.py
index 9dce0d19..bde04f8e 100644
--- a/tests/pipelines/test_body_3d_keypoints.py
+++ b/tests/pipelines/test_body_3d_keypoints.py
@@ -28,7 +28,12 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_run_modelhub_with_video_file(self):
         body_3d_keypoints = pipeline(
             Tasks.body_3d_keypoints, model=self.model_id)
-        self.pipeline_inference(body_3d_keypoints, self.test_video)
+        pipeline_input = {
+            'input_video': self.test_video,
+            'output_video_path': './result.mp4'
+        }
+        self.pipeline_inference(
+            body_3d_keypoints, pipeline_input=pipeline_input)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_modelhub_with_video_stream(self):
@@ -37,12 +42,12 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
         if not cap.isOpened():
             raise Exception('modelscope error: %s cannot be decoded by OpenCV.' %
                             (self.test_video))
-        self.pipeline_inference(body_3d_keypoints, cap)
-
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_run_modelhub_default_model(self):
-        body_3d_keypoints = pipeline(Tasks.body_3d_keypoints)
-        self.pipeline_inference(body_3d_keypoints, self.test_video)
+        pipeline_input = {
+            'input_video': cap,
+            'output_video_path': './result.mp4'
+        }
+        self.pipeline_inference(
+            body_3d_keypoints, pipeline_input=pipeline_input)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_demo_compatibility(self):
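
Usage sketch (not part of the patch): based on the updated test in tests/pipelines/test_body_3d_keypoints.py, the pipeline now takes a dict with 'input_video' and an optional 'output_video_path' instead of a bare video path. This is a minimal sketch of how a caller might exercise the change; the model id and file paths are illustrative placeholders, and the rendered video only appears in the result when the model config enables the "render" option.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Placeholder model id and video path; substitute real values.
    body_3d_keypoints = pipeline(Tasks.body_3d_keypoints, model='<model_id>')
    result = body_3d_keypoints({
        'input_video': 'input.mp4',           # file path or a cv2.VideoCapture object
        'output_video_path': './result.mp4',  # optional; a temporary .mp4 is used if omitted
    })

    poses = result[OutputKeys.POSES]                # [frame_num, joint_num, 3] joints in camera coordinates
    timestamps = result[OutputKeys.TIMESTAMPS]      # one 'HH:MM:SS.mmm' string per frame
    rendered = result.get(OutputKeys.OUTPUT_VIDEO)  # rendered video path when 'render' is configured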