From de6d84cb9781181dc49b8162fb1f5a23fe4c8993 Mon Sep 17 00:00:00 2001
From: "hanyuan.chy"
Date: Thu, 20 Oct 2022 19:33:06 +0800
Subject: [PATCH] [to #42322933] Fix collate_fn exception when chaining pipelines
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix the collate_fn exception raised when pipelines are chained:
collate_fn now returns an empty tensor for an empty list/tuple instead
of indexing data[0], Body3DKeypointsPipeline returns early with
success=False when no person is detected in a frame, and postprocess
always resolves the output video path and reports it under
OutputKeys.OUTPUT_VIDEO.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10457058
---
 modelscope/pipelines/base.py                  |  2 ++
 .../cv/body_3d_keypoints_pipeline.py          | 21 ++++++++++++++-----
 2 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index ea329be4..644749fc 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -433,6 +433,8 @@ def collate_fn(data, device):
     if isinstance(data, dict) or isinstance(data, Mapping):
         return type(data)({k: collate_fn(v, device) for k, v in data.items()})
     elif isinstance(data, (tuple, list)):
+        if 0 == len(data):
+            return torch.Tensor([])
         if isinstance(data[0], (int, float)):
             return default_collate(data).to(device)
         else:
diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
index 3502915c..8522ceff 100644
--- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
+++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
@@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline):
         max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME  # max video frame number to be predicted 3D joints
         for i, frame in enumerate(video_frames):
             kps_2d = self.human_body_2d_kps_detector(frame)
+            if [] == kps_2d.get('boxes'):
+                res = {
+                    'success': False,
+                    'msg': f'fail to detect person at image frame {i}'
+                }
+                return res
+
             box = kps_2d['boxes'][
                 0]  # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
             pose = kps_2d['keypoints'][0]  # keypoints: [15, 2]
@@ -180,7 +187,15 @@ class Body3DKeypointsPipeline(Pipeline):
         return res
 
     def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []}
+        output_video_path = kwargs.get('output_video', None)
+        if output_video_path is None:
+            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+
+        res = {
+            OutputKeys.KEYPOINTS: [],
+            OutputKeys.TIMESTAMPS: [],
+            OutputKeys.OUTPUT_VIDEO: output_video_path
+        }
 
         if not input['success']:
             pass
@@ -189,10 +204,6 @@ class Body3DKeypointsPipeline(Pipeline):
             pred_3d_pose = poses.data.cpu().numpy()[
                 0]  # [frame_num, joint_num, joint_dim]
 
-            output_video_path = kwargs.get('output_video', None)
-            if output_video_path is None:
-                output_video_path = tempfile.NamedTemporaryFile(
-                    suffix='.mp4').name
             if 'render' in self.keypoint_model_3d.cfg.keys():
                 self.render_prediction(pred_3d_pose, output_video_path)
             res[OutputKeys.OUTPUT_VIDEO] = output_video_path
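
Not part of the patch itself: a minimal, self-contained sketch of the failure the first hunk guards against. When pipelines are chained and the upstream stage returns no results (for example an empty 'boxes' list from the 2D keypoint detector), collate_fn receives an empty list and data[0] raises IndexError; the added guard returns an empty tensor instead. The function body below is abridged from the patched modelscope/pipelines/base.py; the tensor branch and the per-element fallback are assumptions for illustration only, not the library's full implementation.

import torch
from collections.abc import Mapping
from torch.utils.data.dataloader import default_collate


def collate_fn(data, device):
    """Abridged sketch of the patched collate_fn: move collated data to `device`."""
    if isinstance(data, dict) or isinstance(data, Mapping):
        return type(data)({k: collate_fn(v, device) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        if 0 == len(data):
            # The patched guard: an empty list from an upstream pipeline
            # (e.g. no detected person boxes) no longer reaches data[0] below.
            return torch.Tensor([])
        if isinstance(data[0], (int, float)):
            return default_collate(data).to(device)
        return [collate_fn(v, device) for v in data]  # assumption: recurse per element
    elif isinstance(data, torch.Tensor):
        return data.to(device)
    return data  # assumption: leave other types untouched


# Before the guard this raised IndexError; now it yields empty tensors.
print(collate_fn({'boxes': [], 'keypoints': []}, device='cpu'))

Returning an empty tensor (rather than None) keeps the collated output uniformly tensor-typed, so downstream stages that call .to(device) or inspect shapes do not need a special case for empty results.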