
[to #44902165] feat: Add input signature for pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10233235

* add input definitions for pipelines

* support multiple input formats for a single pipeline

* change parameters for the body_3d_keypoints pipeline
master
wenmeng.zwm 3 years ago
parent
commit 65bea053af
6 changed files with 296 additions and 34 deletions
  1. +236 -0   modelscope/pipeline_inputs.py
  2. +38  -1   modelscope/pipelines/base.py
  3. +10  -11  modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
  4. +3   -9   tests/pipelines/test_body_3d_keypoints.py
  5. +2   -2   tests/pipelines/test_mplug_tasks.py
  6. +7   -11  tests/pipelines/test_ofa_tasks.py

+236 -0  modelscope/pipeline_inputs.py

@@ -0,0 +1,236 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import cv2
import numpy as np
from PIL import Image

from modelscope.models.base.base_head import Input
from modelscope.utils.constant import Tasks


class InputKeys(object):
    IMAGE = 'image'
    TEXT = 'text'
    VIDEO = 'video'


class InputType(object):
    IMAGE = 'image'
    TEXT = 'text'
    AUDIO = 'audio'
    VIDEO = 'video'
    BOX = 'box'
    DICT = 'dict'
    LIST = 'list'
    INT = 'int'


INPUT_TYPE = {
    InputType.IMAGE: (str, np.ndarray, Image.Image),
    InputType.TEXT: str,
    InputType.AUDIO: (str, np.ndarray),
    InputType.VIDEO: (str, np.ndarray, cv2.VideoCapture),
    InputType.BOX: (list, np.ndarray),
    InputType.DICT: (dict, type(None)),
    InputType.LIST: (list, type(None)),
    InputType.INT: int,
}


def check_input_type(input_type, input):
    expected_type = INPUT_TYPE[input_type]
    assert isinstance(input, expected_type), \
        f'invalid input type for {input_type}, expected {expected_type} but got {type(input)}\n {input}'


TASK_INPUTS = {
    # if the task input is a single variable, the value is an InputType
    # if the task input is a tuple, the value is a tuple of InputType
    # if the task input is a dict, the value is a dict of InputType, whose keys
    # match the keys expected in the pipeline input dict
    # if the task input supports multiple formats, the value is a list in which
    # each element describes one accepted format as defined above.
    # ============ vision tasks ===================
    Tasks.ocr_detection:
    InputType.IMAGE,
    Tasks.ocr_recognition:
    InputType.IMAGE,
    Tasks.face_2d_keypoints:
    InputType.IMAGE,
    Tasks.face_detection:
    InputType.IMAGE,
    Tasks.facial_expression_recognition:
    InputType.IMAGE,
    Tasks.face_recognition:
    InputType.IMAGE,
    Tasks.human_detection:
    InputType.IMAGE,
    Tasks.face_image_generation:
    InputType.INT,
    Tasks.image_classification:
    InputType.IMAGE,
    Tasks.image_object_detection:
    InputType.IMAGE,
    Tasks.image_segmentation:
    InputType.IMAGE,
    Tasks.portrait_matting:
    InputType.IMAGE,

    # image editing tasks that take a single image as input
    Tasks.skin_retouching:
    InputType.IMAGE,
    Tasks.image_super_resolution:
    InputType.IMAGE,
    Tasks.image_colorization:
    InputType.IMAGE,
    Tasks.image_color_enhancement:
    InputType.IMAGE,
    Tasks.image_denoising:
    InputType.IMAGE,
    Tasks.image_portrait_enhancement:
    InputType.IMAGE,
    Tasks.crowd_counting:
    InputType.IMAGE,

    # image generation tasks that take a single image as input
    Tasks.image_to_image_generation:
    InputType.IMAGE,
    Tasks.image_to_image_translation:
    InputType.IMAGE,
    Tasks.image_style_transfer:
    InputType.IMAGE,
    Tasks.image_portrait_stylization:
    InputType.IMAGE,
    Tasks.live_category:
    InputType.VIDEO,
    Tasks.action_recognition:
    InputType.VIDEO,
    Tasks.body_2d_keypoints:
    InputType.IMAGE,
    Tasks.body_3d_keypoints:
    InputType.VIDEO,
    Tasks.hand_2d_keypoints:
    InputType.IMAGE,
    Tasks.video_single_object_tracking: (InputType.VIDEO, InputType.BOX),
    Tasks.video_category:
    InputType.VIDEO,
    Tasks.product_retrieval_embedding:
    InputType.IMAGE,
    Tasks.video_embedding:
    InputType.VIDEO,
    Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE),
    Tasks.text_driven_segmentation: {
        InputKeys.IMAGE: InputType.IMAGE,
        InputKeys.TEXT: InputType.TEXT
    },
    Tasks.shop_segmentation:
    InputType.IMAGE,
    Tasks.movie_scene_segmentation:
    InputType.VIDEO,

    # ============ nlp tasks ===================
    Tasks.text_classification: [
        InputType.TEXT,
        (InputType.TEXT, InputType.TEXT),
        {
            'text': InputType.TEXT,
            'text2': InputType.TEXT
        },
    ],
    Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT),
    Tasks.nli: (InputType.TEXT, InputType.TEXT),
    Tasks.sentiment_classification:
    InputType.TEXT,
    Tasks.zero_shot_classification:
    InputType.TEXT,
    Tasks.relation_extraction:
    InputType.TEXT,
    Tasks.translation:
    InputType.TEXT,
    Tasks.word_segmentation:
    InputType.TEXT,
    Tasks.part_of_speech:
    InputType.TEXT,
    Tasks.named_entity_recognition:
    InputType.TEXT,
    Tasks.text_error_correction:
    InputType.TEXT,
    Tasks.sentence_embedding: {
        'source_sentence': InputType.LIST,
        'sentences_to_compare': InputType.LIST,
    },
    Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT),
    Tasks.text_generation:
    InputType.TEXT,
    Tasks.fill_mask:
    InputType.TEXT,
    Tasks.task_oriented_conversation: {
        'user_input': InputType.TEXT,
        'history': InputType.DICT,
    },
    Tasks.table_question_answering: {
        'question': InputType.TEXT,
        'history_sql': InputType.DICT,
    },
    Tasks.faq_question_answering: {
        'query_set': InputType.LIST,
        'support_set': InputType.LIST,
    },

    # ============ audio tasks ===================
    Tasks.auto_speech_recognition:
    InputType.AUDIO,
    Tasks.speech_signal_process:
    InputType.AUDIO,
    Tasks.acoustic_echo_cancellation: {
        'nearend_mic': InputType.AUDIO,
        'farend_speech': InputType.AUDIO
    },
    Tasks.acoustic_noise_suppression:
    InputType.AUDIO,
    Tasks.text_to_speech:
    InputType.TEXT,
    Tasks.keyword_spotting:
    InputType.AUDIO,

    # ============ multi-modal tasks ===================
    Tasks.image_captioning:
    InputType.IMAGE,
    Tasks.visual_grounding: {
        'image': InputType.IMAGE,
        'text': InputType.TEXT
    },
    Tasks.text_to_image_synthesis: {
        'text': InputType.TEXT,
    },
    Tasks.multi_modal_embedding: {
        'img': InputType.IMAGE,
        'text': InputType.TEXT
    },
    Tasks.generative_multi_modal_embedding: {
        'image': InputType.IMAGE,
        'text': InputType.TEXT
    },
    Tasks.multi_modal_similarity: {
        'img': InputType.IMAGE,
        'text': InputType.TEXT
    },
    Tasks.visual_question_answering: {
        'image': InputType.IMAGE,
        'text': InputType.TEXT
    },
    Tasks.visual_entailment: {
        'image': InputType.IMAGE,
        'text': InputType.TEXT,
        'text2': InputType.TEXT,
    },
    Tasks.action_detection:
    InputType.VIDEO,
    Tasks.image_reid_person:
    InputType.IMAGE,
    Tasks.video_inpainting: {
        'video_input_path': InputType.TEXT,
        'video_output_path': InputType.TEXT,
        'mask_path': InputType.TEXT,
    }
}
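
For reference, a minimal usage sketch of the new module (not part of this change): a task's declared signature is looked up in TASK_INPUTS and a candidate input is validated with check_input_type. The image path below is a placeholder taken from the test data referenced later in this commit.

# Usage sketch only; the task key and image path are illustrative.
from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
from modelscope.utils.constant import Tasks

input_type = TASK_INPUTS[Tasks.image_classification]  # InputType.IMAGE, i.e. 'image'
check_input_type(input_type, 'data/test/images/image_classification.png')  # a str path is a valid image input

try:
    check_input_type(input_type, 42)  # an int is not a valid image input
except AssertionError as err:
    print('rejected:', err)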

+38 -1  modelscope/pipelines/base.py

@@ -13,6 +13,7 @@ import numpy as np
from modelscope.models.base import Model
from modelscope.msdatasets import MsDataset
from modelscope.outputs import TASK_OUTPUTS
from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
from modelscope.preprocessors import Preprocessor
from modelscope.utils.config import Config
from modelscope.utils.constant import Frameworks, ModelFile
@@ -210,7 +211,7 @@ class Pipeline(ABC):
        preprocess_params = kwargs.get('preprocess_params', {})
        forward_params = kwargs.get('forward_params', {})
        postprocess_params = kwargs.get('postprocess_params', {})
        self._check_input(input)
        out = self.preprocess(input, **preprocess_params)
        with device_placement(self.framework, self.device_name):
            if self.framework == Frameworks.torch:
@@ -225,6 +226,42 @@ class Pipeline(ABC):
        self._check_output(out)
        return out

    def _check_input(self, input):
        task_name = self.group_key
        if task_name in TASK_INPUTS:
            input_type = TASK_INPUTS[task_name]

            # if multiple input formats are defined, first find the
            # one that matches the input data, then check against it
            if isinstance(input_type, list):
                matched_type = None
                for t in input_type:
                    if type(t) == type(input):
                        matched_type = t
                        break
                if matched_type is None:
                    err_msg = 'input data format for current pipeline should be one of following: \n'
                    for t in input_type:
                        err_msg += f'{t}\n'
                    raise ValueError(err_msg)
                else:
                    input_type = matched_type

            if isinstance(input_type, str):
                check_input_type(input_type, input)
            elif isinstance(input_type, tuple):
                for t, input_ele in zip(input_type, input):
                    check_input_type(t, input_ele)
            elif isinstance(input_type, dict):
                for k in input_type.keys():
                    # allow single input for multi-modal models
                    if k in input:
                        check_input_type(input_type[k], input[k])
            else:
                raise ValueError(f'invalid input_type definition {input_type}')
        else:
            logger.warning(f'task {task_name} input definition is missing')

    def _check_output(self, input):
        # this attribute is dynamically attached by registry
        # when cls is registered in registry using task name
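
To make the branching above concrete, here is a standalone sketch of the matching rule for a task that declares several accepted formats (text_classification): the entry whose Python type matches the input's type is selected, then each field is checked. This mirrors _check_input for illustration only; the real check lives in modelscope/pipelines/base.py.

# Illustrative sketch, assuming modelscope is installed; not the Pipeline class itself.
from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
from modelscope.utils.constant import Tasks

def resolve_and_check(task_name, data):
    input_type = TASK_INPUTS[task_name]
    if isinstance(input_type, list):  # several accepted formats
        input_type = next(
            (t for t in input_type if type(t) == type(data)), None)
        if input_type is None:
            raise ValueError(f'unsupported input format: {type(data)}')
    if isinstance(input_type, str):  # single value
        check_input_type(input_type, data)
    elif isinstance(input_type, tuple):  # positional tuple
        for t, ele in zip(input_type, data):
            check_input_type(t, ele)
    elif isinstance(input_type, dict):  # named fields
        for k, t in input_type.items():
            if k in data:
                check_input_type(t, data[k])

resolve_and_check(Tasks.text_classification, 'a single sentence')
resolve_and_check(Tasks.text_classification, ('premise', 'hypothesis'))
resolve_and_check(Tasks.text_classification, {'text': 'premise', 'text2': 'hypothesis'})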


+10 -11  modelscope/pipelines/cv/body_3d_keypoints_pipeline.py

@@ -132,12 +132,7 @@ class Body3DKeypointsPipeline(Pipeline):
            device='gpu' if torch.cuda.is_available() else 'cpu')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        video_url = input.get('input_video')
        self.output_video_path = input.get('output_video_path')
        if self.output_video_path is None:
            self.output_video_path = tempfile.NamedTemporaryFile(
                suffix='.mp4').name

        video_url = input
        video_frames = self.read_video_frames(video_url)
        if 0 == len(video_frames):
            res = {'success': False, 'msg': 'get video frame failed.'}
@@ -194,9 +189,13 @@ class Body3DKeypointsPipeline(Pipeline):
        pred_3d_pose = poses.data.cpu().numpy()[
            0]  # [frame_num, joint_num, joint_dim]

        output_video_path = kwargs.get('output_video', None)
        if output_video_path is None:
            output_video_path = tempfile.NamedTemporaryFile(
                suffix='.mp4').name
        if 'render' in self.keypoint_model_3d.cfg.keys():
            self.render_prediction(pred_3d_pose)
            res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path
            self.render_prediction(pred_3d_pose, output_video_path)
            res[OutputKeys.OUTPUT_VIDEO] = output_video_path

        res[OutputKeys.POSES] = pred_3d_pose
        res[OutputKeys.TIMESTAMPS] = self.timestamps
@@ -252,12 +251,12 @@ class Body3DKeypointsPipeline(Pipeline):
        cap.release()
        return frames

    def render_prediction(self, pose3d_cam_rr):
    def render_prediction(self, pose3d_cam_rr, output_video_path):
        """render predict result 3d poses.

        Args:
            pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints
            output_video_path (str): output path for video
        Returns:
        """
        frame_num = pose3d_cam_rr.shape[0]
@@ -359,4 +358,4 @@ class Body3DKeypointsPipeline(Pipeline):
        # save mp4
        Writer = writers['ffmpeg']
        writer = Writer(fps=self.fps, metadata={}, bitrate=4096)
        ani.save(self.output_video_path, writer=writer)
        ani.save(output_video_path, writer=writer)
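
With this change the body_3d_keypoints pipeline takes the video directly (a file path, ndarray, or cv2.VideoCapture, per TASK_INPUTS) and the output path moves to an optional output_video keyword, which defaults to a temporary .mp4. A sketch of the new call, with placeholder model id and video path:

# Calling-convention sketch; '<model-id>' and the video path are placeholders,
# not values from this commit.
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

body_3d_keypoints = pipeline(Tasks.body_3d_keypoints, model='<model-id>')
result = body_3d_keypoints('path/to/video.mp4', output_video='./result.mp4')
poses = np.array(result[OutputKeys.POSES])  # [frame_num, joint_num, joint_dim]
print('result 3d points shape', poses.shape)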

+3 -9  tests/pipelines/test_body_3d_keypoints.py

@@ -20,7 +20,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
        self.task = Tasks.body_3d_keypoints

    def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
        output = pipeline(pipeline_input)
        output = pipeline(pipeline_input, output_video='./result.mp4')
        poses = np.array(output[OutputKeys.POSES])
        print(f'result 3d points shape {poses.shape}')

@@ -28,10 +28,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
    def test_run_modelhub_with_video_file(self):
        body_3d_keypoints = pipeline(
            Tasks.body_3d_keypoints, model=self.model_id)
        pipeline_input = {
            'input_video': self.test_video,
            'output_video_path': './result.mp4'
        }
        pipeline_input = self.test_video
        self.pipeline_inference(
            body_3d_keypoints, pipeline_input=pipeline_input)

@@ -42,10 +39,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
        if not cap.isOpened():
            raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                            % (self.test_video))
        pipeline_input = {
            'input_video': cap,
            'output_video_path': './result.mp4'
        }
        pipeline_input = self.test_video
        self.pipeline_inference(
            body_3d_keypoints, pipeline_input=pipeline_input)



+2 -2  tests/pipelines/test_mplug_tasks.py

@@ -26,7 +26,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
            model=model,
        )
        image = Image.open('data/test/images/image_mplug_vqa.jpg')
        result = pipeline_caption({'image': image})
        result = pipeline_caption(image)
        print(result[OutputKeys.CAPTION])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -35,7 +35,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
            Tasks.image_captioning,
            model='damo/mplug_image-captioning_coco_base_en')
        image = Image.open('data/test/images/image_mplug_vqa.jpg')
        result = pipeline_caption({'image': image})
        result = pipeline_caption(image)
        print(result[OutputKeys.CAPTION])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')


+7 -11  tests/pipelines/test_ofa_tasks.py

@@ -34,7 +34,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
            model=model,
        )
        image = 'data/test/images/image_captioning.png'
        result = img_captioning({'image': image})
        result = img_captioning(image)
        print(result[OutputKeys.CAPTION])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -42,8 +42,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
        img_captioning = pipeline(
            Tasks.image_captioning,
            model='damo/ofa_image-caption_coco_large_en')
        result = img_captioning(
            {'image': 'data/test/images/image_captioning.png'})
        result = img_captioning('data/test/images/image_captioning.png')
        print(result[OutputKeys.CAPTION])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -52,8 +51,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
            'damo/ofa_image-classification_imagenet_large_en')
        ofa_pipe = pipeline(Tasks.image_classification, model=model)
        image = 'data/test/images/image_classification.png'
        input = {'image': image}
        result = ofa_pipe(input)
        result = ofa_pipe(image)
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -62,8 +60,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
            Tasks.image_classification,
            model='damo/ofa_image-classification_imagenet_large_en')
        image = 'data/test/images/image_classification.png'
        input = {'image': image}
        result = ofa_pipe(input)
        result = ofa_pipe(image)
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -99,8 +96,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
        ofa_pipe = pipeline(Tasks.text_classification, model=model)
        text = 'One of our number will carry out your instructions minutely.'
        text2 = 'A member of my team will execute your orders with immense precision.'
        input = {'text': text, 'text2': text2}
        result = ofa_pipe(input)
        result = ofa_pipe((text, text2))
        result = ofa_pipe({'text': text, 'text2': text2})
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -110,8 +107,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
            model='damo/ofa_text-classification_mnli_large_en')
        text = 'One of our number will carry out your instructions minutely.'
        text2 = 'A member of my team will execute your orders with immense precision.'
        input = {'text': text, 'text2': text2}
        result = ofa_pipe(input)
        result = ofa_pipe((text, text2))
        print(result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
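
For completeness, the call styles exercised by the updated OFA tests, shown outside the test class (the model ids are the ones referenced above; availability on the model hub is assumed, not verified here):

# Sketch of the simplified inputs accepted after this change.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

img_captioning = pipeline(
    Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en')
img_captioning('data/test/images/image_captioning.png')  # bare image path

ofa_pipe = pipeline(
    Tasks.text_classification,
    model='damo/ofa_text-classification_mnli_large_en')
ofa_pipe(('premise text.', 'hypothesis text.'))  # (text, text2) tuple
ofa_pipe({'text': 'premise text.', 'text2': 'hypothesis text.'})  # named dict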

