
[to #42322933] Fix collate_fn exception when chaining pipelines

Fix collate_fn exception when chaining pipelines
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10457058
master
hanyuan.chy, yingda.chen · 3 years ago
commit de6d84cb97
2 changed files with 18 additions and 5 deletions:
  1. modelscope/pipelines/base.py (+2, -0)
  2. modelscope/pipelines/cv/body_3d_keypoints_pipeline.py (+16, -5)

modelscope/pipelines/base.py (+2, -0)

@@ -433,6 +433,8 @@ def collate_fn(data, device):
     if isinstance(data, dict) or isinstance(data, Mapping):
         return type(data)({k: collate_fn(v, device) for k, v in data.items()})
     elif isinstance(data, (tuple, list)):
+        if 0 == len(data):
+            return torch.Tensor([])
         if isinstance(data[0], (int, float)):
             return default_collate(data).to(device)
         else:
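For context, the guard matters because indexing data[0] on an empty list raises IndexError, which is what happened when one pipeline's output was fed into another and an intermediate step produced an empty list. A minimal runnable sketch of the patched behaviour (a simplified stand-in for modelscope.pipelines.base.collate_fn, not the full implementation):

# Simplified stand-in for the patched collate_fn; assumes torch is installed.
from collections.abc import Mapping

import torch
from torch.utils.data.dataloader import default_collate


def collate_fn(data, device):
    if isinstance(data, Mapping):
        return type(data)({k: collate_fn(v, device) for k, v in data.items()})
    elif isinstance(data, (tuple, list)):
        if 0 == len(data):
            # Without this guard, data[0] below raises IndexError on an empty batch.
            return torch.Tensor([])
        if isinstance(data[0], (int, float)):
            return default_collate(data).to(device)
        return [collate_fn(v, device) for v in data]
    return data


print(collate_fn({'boxes': []}, 'cpu'))  # {'boxes': tensor([])} instead of an IndexError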


modelscope/pipelines/cv/body_3d_keypoints_pipeline.py (+16, -5)

@@ -143,6 +143,13 @@ class Body3DKeypointsPipeline(Pipeline):
         max_frame = self.keypoint_model_3d.cfg.model.INPUT.MAX_FRAME  # max video frame number to be predicted 3D joints
         for i, frame in enumerate(video_frames):
             kps_2d = self.human_body_2d_kps_detector(frame)
+            if [] == kps_2d.get('boxes'):
+                res = {
+                    'success': False,
+                    'msg': f'fail to detect person at image frame {i}'
+                }
+                return res
+
             box = kps_2d['boxes'][
                 0]  # box: [[[x1, y1], [x2, y2]]], N human boxes per frame, [0] represent using first detected bbox
             pose = kps_2d['keypoints'][0]  # keypoints: [15, 2]
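The added check returns a structured failure as soon as a frame has no detected person, instead of crashing later on kps_2d['boxes'][0]. A hedged, self-contained sketch of the same guard-clause pattern, where detector is a hypothetical stand-in for self.human_body_2d_kps_detector:

# Hypothetical helper illustrating the early-return guard; not part of the commit.
from typing import Any, Callable, Dict, List


def collect_2d_poses(video_frames: List[Any],
                     detector: Callable[[Any], Dict[str, Any]]) -> Dict[str, Any]:
    poses = []
    for i, frame in enumerate(video_frames):
        kps_2d = detector(frame)
        if not kps_2d.get('boxes'):  # empty list or missing key: no person found
            return {'success': False,
                    'msg': f'fail to detect person at image frame {i}'}
        poses.append(kps_2d['keypoints'][0])  # first detected person, shape [15, 2]
    return {'success': True, 'poses': poses}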
@@ -180,7 +187,15 @@
             return res

     def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        res = {OutputKeys.KEYPOINTS: [], OutputKeys.TIMESTAMPS: []}
+        output_video_path = kwargs.get('output_video', None)
+        if output_video_path is None:
+            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+
+        res = {
+            OutputKeys.KEYPOINTS: [],
+            OutputKeys.TIMESTAMPS: [],
+            OutputKeys.OUTPUT_VIDEO: output_video_path
+        }

         if not input['success']:
             pass
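The output path is now resolved at the top of postprocess, so OutputKeys.OUTPUT_VIDEO is present in the result even when the 3D stage is skipped (e.g. after a failed detection upstream). A small sketch of just that default-path logic, assuming only the standard library; resolve_output_video is an illustrative helper, not a function in the codebase:

# Illustrative helper mirroring the default output-path logic in postprocess.
import tempfile


def resolve_output_video(**kwargs) -> str:
    output_video_path = kwargs.get('output_video', None)
    if output_video_path is None:
        # Only the unique .mp4 file name is kept; the temp file handle itself
        # may be cleaned up when garbage-collected.
        output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
    return output_video_path


print(resolve_output_video())                        # e.g. /tmp/tmpXXXXXXXX.mp4
print(resolve_output_video(output_video='out.mp4'))  # out.mp4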
@@ -189,10 +204,6 @@
             pred_3d_pose = poses.data.cpu().numpy()[
                 0]  # [frame_num, joint_num, joint_dim]

-            output_video_path = kwargs.get('output_video', None)
-            if output_video_path is None:
-                output_video_path = tempfile.NamedTemporaryFile(
-                    suffix='.mp4').name
             if 'render' in self.keypoint_model_3d.cfg.keys():
                 self.render_prediction(pred_3d_pose, output_video_path)
             res[OutputKeys.OUTPUT_VIDEO] = output_video_path
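A hedged end-to-end usage sketch of how this pipeline might be invoked; the model id and whether output_video is forwarded to postprocess at call time are assumptions to verify against the ModelScope model card, not guarantees from this commit:

# Usage sketch; model id and kwarg forwarding are assumptions (see lead-in).
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

body_3d = pipeline(
    Tasks.body_3d_keypoints,
    model='damo/cv_canonical_body-3d-keypoints_video')  # assumed model id
result = body_3d('input_video.mp4', output_video='skeleton.mp4')
# After this commit, the result dict carries OutputKeys.OUTPUT_VIDEO even when
# person detection fails on some frame.
print(result[OutputKeys.OUTPUT_VIDEO])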

