From 65bea053afe17373c1498cc466b59e9514a5febc Mon Sep 17 00:00:00 2001
From: "wenmeng.zwm"
Date: Sun, 9 Oct 2022 18:59:07 +0800
Subject: [PATCH] [to #44902165] feat: Add input signature for pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10233235

* add input definition for pipeline
* support multiple input formats for a single pipeline
* change param for body_3d_keypoints
---
 modelscope/pipeline_inputs.py                 | 236 ++++++++++++++++++
 modelscope/pipelines/base.py                  |  39 ++-
 .../cv/body_3d_keypoints_pipeline.py          |  21 +-
 tests/pipelines/test_body_3d_keypoints.py     |  12 +-
 tests/pipelines/test_mplug_tasks.py           |   4 +-
 tests/pipelines/test_ofa_tasks.py             |  18 +-
 6 files changed, 296 insertions(+), 34 deletions(-)
 create mode 100644 modelscope/pipeline_inputs.py

diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py
new file mode 100644
index 00000000..de9814a7
--- /dev/null
+++ b/modelscope/pipeline_inputs.py
@@ -0,0 +1,236 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import cv2
+import numpy as np
+from PIL import Image
+
+from modelscope.models.base.base_head import Input
+from modelscope.utils.constant import Tasks
+
+
+class InputKeys(object):
+    IMAGE = 'image'
+    TEXT = 'text'
+    VIDEO = 'video'
+
+
+class InputType(object):
+    IMAGE = 'image'
+    TEXT = 'text'
+    AUDIO = 'audio'
+    VIDEO = 'video'
+    BOX = 'box'
+    DICT = 'dict'
+    LIST = 'list'
+    INT = 'int'
+
+
+INPUT_TYPE = {
+    InputType.IMAGE: (str, np.ndarray, Image.Image),
+    InputType.TEXT: str,
+    InputType.AUDIO: (str, np.ndarray),
+    InputType.VIDEO: (str, np.ndarray, cv2.VideoCapture),
+    InputType.BOX: (list, np.ndarray),
+    InputType.DICT: (dict, type(None)),
+    InputType.LIST: (list, type(None)),
+    InputType.INT: int,
+}
+
+
+def check_input_type(input_type, input):
+    expected_type = INPUT_TYPE[input_type]
+    assert isinstance(input, expected_type), \
+        f'invalid input type for {input_type}, expected {expected_type} but got {type(input)}\n {input}'
+
+
+TASK_INPUTS = {
+    # if the task input is a single variable, the value is an InputType
+    # if the task input is a tuple, the value is a tuple of InputType
+    # if the task input is a dict, the value is a dict of InputType, where
+    # each key equals the one needed in the pipeline input dict
+    # if the task input is a list, the value is a list of supported input
+    # formats, where each element corresponds to one format described above.
+
+    # ============ vision tasks ===================
+    Tasks.ocr_detection:
+    InputType.IMAGE,
+    Tasks.ocr_recognition:
+    InputType.IMAGE,
+    Tasks.face_2d_keypoints:
+    InputType.IMAGE,
+    Tasks.face_detection:
+    InputType.IMAGE,
+    Tasks.facial_expression_recognition:
+    InputType.IMAGE,
+    Tasks.face_recognition:
+    InputType.IMAGE,
+    Tasks.human_detection:
+    InputType.IMAGE,
+    Tasks.face_image_generation:
+    InputType.INT,
+    Tasks.image_classification:
+    InputType.IMAGE,
+    Tasks.image_object_detection:
+    InputType.IMAGE,
+    Tasks.image_segmentation:
+    InputType.IMAGE,
+    Tasks.portrait_matting:
+    InputType.IMAGE,
+
+    # image editing task result for a single image
+    Tasks.skin_retouching:
+    InputType.IMAGE,
+    Tasks.image_super_resolution:
+    InputType.IMAGE,
+    Tasks.image_colorization:
+    InputType.IMAGE,
+    Tasks.image_color_enhancement:
+    InputType.IMAGE,
+    Tasks.image_denoising:
+    InputType.IMAGE,
+    Tasks.image_portrait_enhancement:
+    InputType.IMAGE,
+    Tasks.crowd_counting:
+    InputType.IMAGE,
+
+    # image generation task result for a single image
+    Tasks.image_to_image_generation:
+    InputType.IMAGE,
+    Tasks.image_to_image_translation:
+    InputType.IMAGE,
+    Tasks.image_style_transfer:
+    InputType.IMAGE,
+    Tasks.image_portrait_stylization:
+    InputType.IMAGE,
+    Tasks.live_category:
+    InputType.VIDEO,
+    Tasks.action_recognition:
+    InputType.VIDEO,
+    Tasks.body_2d_keypoints:
+    InputType.IMAGE,
+    Tasks.body_3d_keypoints:
+    InputType.VIDEO,
+    Tasks.hand_2d_keypoints:
+    InputType.IMAGE,
+    Tasks.video_single_object_tracking: (InputType.VIDEO, InputType.BOX),
+    Tasks.video_category:
+    InputType.VIDEO,
+    Tasks.product_retrieval_embedding:
+    InputType.IMAGE,
+    Tasks.video_embedding:
+    InputType.VIDEO,
+    Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE),
+    Tasks.text_driven_segmentation: {
+        InputKeys.IMAGE: InputType.IMAGE,
+        InputKeys.TEXT: InputType.TEXT
+    },
+    Tasks.shop_segmentation:
+    InputType.IMAGE,
+    Tasks.movie_scene_segmentation:
+    InputType.VIDEO,
+
+    # ============ nlp tasks ===================
+    Tasks.text_classification: [
+        InputType.TEXT,
+        (InputType.TEXT, InputType.TEXT),
+        {
+            'text': InputType.TEXT,
+            'text2': InputType.TEXT
+        },
+    ],
+    Tasks.sentence_similarity: (InputType.TEXT, InputType.TEXT),
+    Tasks.nli: (InputType.TEXT, InputType.TEXT),
+    Tasks.sentiment_classification:
+    InputType.TEXT,
+    Tasks.zero_shot_classification:
+    InputType.TEXT,
+    Tasks.relation_extraction:
+    InputType.TEXT,
+    Tasks.translation:
+    InputType.TEXT,
+    Tasks.word_segmentation:
+    InputType.TEXT,
+    Tasks.part_of_speech:
+    InputType.TEXT,
+    Tasks.named_entity_recognition:
+    InputType.TEXT,
+    Tasks.text_error_correction:
+    InputType.TEXT,
+    Tasks.sentence_embedding: {
+        'source_sentence': InputType.LIST,
+        'sentences_to_compare': InputType.LIST,
+    },
+    Tasks.passage_ranking: (InputType.TEXT, InputType.TEXT),
+    Tasks.text_generation:
+    InputType.TEXT,
+    Tasks.fill_mask:
+    InputType.TEXT,
+    Tasks.task_oriented_conversation: {
+        'user_input': InputType.TEXT,
+        'history': InputType.DICT,
+    },
+    Tasks.table_question_answering: {
+        'question': InputType.TEXT,
+        'history_sql': InputType.DICT,
+    },
+    Tasks.faq_question_answering: {
+        'query_set': InputType.LIST,
+        'support_set': InputType.LIST,
+    },
+
+    # ============ audio tasks ===================
+    Tasks.auto_speech_recognition:
+    InputType.AUDIO,
+    Tasks.speech_signal_process:
+    InputType.AUDIO,
+    Tasks.acoustic_echo_cancellation: {
+        'nearend_mic': InputType.AUDIO,
+        'farend_speech': InputType.AUDIO
+    },
+    Tasks.acoustic_noise_suppression:
+    InputType.AUDIO,
+    Tasks.text_to_speech:
+    InputType.TEXT,
+    Tasks.keyword_spotting:
+    InputType.AUDIO,
+
+    # ============ multi-modal tasks ===================
+    Tasks.image_captioning:
+    InputType.IMAGE,
+    Tasks.visual_grounding: {
+        'image': InputType.IMAGE,
+        'text': InputType.TEXT
+    },
+    Tasks.text_to_image_synthesis: {
+        'text': InputType.TEXT,
+    },
+    Tasks.multi_modal_embedding: {
+        'img': InputType.IMAGE,
+        'text': InputType.TEXT
+    },
+    Tasks.generative_multi_modal_embedding: {
+        'image': InputType.IMAGE,
+        'text': InputType.TEXT
+    },
+    Tasks.multi_modal_similarity: {
+        'img': InputType.IMAGE,
+        'text': InputType.TEXT
+    },
+    Tasks.visual_question_answering: {
+        'image': InputType.IMAGE,
+        'text': InputType.TEXT
+    },
+    Tasks.visual_entailment: {
+        'image': InputType.IMAGE,
+        'text': InputType.TEXT,
+        'text2': InputType.TEXT,
+    },
+    Tasks.action_detection:
+    InputType.VIDEO,
+    Tasks.image_reid_person:
+    InputType.IMAGE,
+    Tasks.video_inpainting: {
+        'video_input_path': InputType.TEXT,
+        'video_output_path': InputType.TEXT,
+        'mask_path': InputType.TEXT,
+    }
+}
diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py
index c5db2b57..5732a9d7 100644
--- a/modelscope/pipelines/base.py
+++ b/modelscope/pipelines/base.py
@@ -13,6 +13,7 @@ import numpy as np
 from modelscope.models.base import Model
 from modelscope.msdatasets import MsDataset
 from modelscope.outputs import TASK_OUTPUTS
+from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
 from modelscope.preprocessors import Preprocessor
 from modelscope.utils.config import Config
 from modelscope.utils.constant import Frameworks, ModelFile
@@ -210,7 +211,7 @@ class Pipeline(ABC):
         preprocess_params = kwargs.get('preprocess_params', {})
         forward_params = kwargs.get('forward_params', {})
         postprocess_params = kwargs.get('postprocess_params', {})
-
+        self._check_input(input)
         out = self.preprocess(input, **preprocess_params)
         with device_placement(self.framework, self.device_name):
             if self.framework == Frameworks.torch:
@@ -225,6 +226,42 @@ class Pipeline(ABC):
         self._check_output(out)
         return out
 
+    def _check_input(self, input):
+        task_name = self.group_key
+        if task_name in TASK_INPUTS:
+            input_type = TASK_INPUTS[task_name]
+
+            # if multiple input formats are defined, first find the one
+            # that matches the input data, then check against it
+            if isinstance(input_type, list):
+                matched_type = None
+                for t in input_type:
+                    if type(t) == type(input):
+                        matched_type = t
+                        break
+                if matched_type is None:
+                    err_msg = 'input data format for current pipeline should be one of the following: \n'
+                    for t in input_type:
+                        err_msg += f'{t}\n'
+                    raise ValueError(err_msg)
+                else:
+                    input_type = matched_type
+
+            if isinstance(input_type, str):
+                check_input_type(input_type, input)
+            elif isinstance(input_type, tuple):
+                for t, input_ele in zip(input_type, input):
+                    check_input_type(t, input_ele)
+            elif isinstance(input_type, dict):
+                for k in input_type.keys():
+                    # allow single input for multi-modal models
+                    if k in input:
+                        check_input_type(input_type[k], input[k])
+            else:
+                raise ValueError(f'invalid input_type definition {input_type}')
+        else:
+            logger.warning(f'task {task_name} input definition is missing')
+
     def _check_output(self, input):
         # this attribute is dynamically attached by registry
         # when cls is registered in registry using task name
diff --git a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
index 474c0e54..b0faa1e0 100644
--- a/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
+++ b/modelscope/pipelines/cv/body_3d_keypoints_pipeline.py
@@ -132,12 +132,7 @@ class Body3DKeypointsPipeline(Pipeline):
             device='gpu' if torch.cuda.is_available() else 'cpu')
 
     def preprocess(self, input: Input) -> Dict[str, Any]:
-        video_url = input.get('input_video')
-        self.output_video_path = input.get('output_video_path')
-        if self.output_video_path is None:
-            self.output_video_path = tempfile.NamedTemporaryFile(
-                suffix='.mp4').name
-
+        video_url = input
         video_frames = self.read_video_frames(video_url)
         if 0 == len(video_frames):
             res = {'success': False, 'msg': 'get video frame failed.'}
@@ -194,9 +189,13 @@ class Body3DKeypointsPipeline(Pipeline):
 
         pred_3d_pose = poses.data.cpu().numpy()[
             0]  # [frame_num, joint_num, joint_dim]
+        output_video_path = kwargs.get('output_video', None)
+        if output_video_path is None:
+            output_video_path = tempfile.NamedTemporaryFile(
+                suffix='.mp4').name
         if 'render' in self.keypoint_model_3d.cfg.keys():
-            self.render_prediction(pred_3d_pose)
-            res[OutputKeys.OUTPUT_VIDEO] = self.output_video_path
+            self.render_prediction(pred_3d_pose, output_video_path)
+            res[OutputKeys.OUTPUT_VIDEO] = output_video_path
         res[OutputKeys.POSES] = pred_3d_pose
         res[OutputKeys.TIMESTAMPS] = self.timestamps
 
@@ -252,12 +251,12 @@ class Body3DKeypointsPipeline(Pipeline):
         cap.release()
         return frames
 
-    def render_prediction(self, pose3d_cam_rr):
+    def render_prediction(self, pose3d_cam_rr, output_video_path):
         """render predict result 3d poses.
 
         Args:
             pose3d_cam_rr (nd.array): [frame_num, joint_num, joint_dim], 3d pose joints
-
+            output_video_path (str): output path for video
         Returns:
         """
         frame_num = pose3d_cam_rr.shape[0]
@@ -359,4 +358,4 @@ class Body3DKeypointsPipeline(Pipeline):
         # save mp4
         Writer = writers['ffmpeg']
         writer = Writer(fps=self.fps, metadata={}, bitrate=4096)
-        ani.save(self.output_video_path, writer=writer)
+        ani.save(output_video_path, writer=writer)
diff --git a/tests/pipelines/test_body_3d_keypoints.py b/tests/pipelines/test_body_3d_keypoints.py
index bde04f8e..6f27f12d 100644
--- a/tests/pipelines/test_body_3d_keypoints.py
+++ b/tests/pipelines/test_body_3d_keypoints.py
@@ -20,7 +20,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
         self.task = Tasks.body_3d_keypoints
 
     def pipeline_inference(self, pipeline: Pipeline, pipeline_input):
-        output = pipeline(pipeline_input)
+        output = pipeline(pipeline_input, output_video='./result.mp4')
         poses = np.array(output[OutputKeys.POSES])
         print(f'result 3d points shape {poses.shape}')
 
@@ -28,10 +28,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
     def test_run_modelhub_with_video_file(self):
         body_3d_keypoints = pipeline(
            Tasks.body_3d_keypoints, model=self.model_id)
-        pipeline_input = {
-            'input_video': self.test_video,
-            'output_video_path': './result.mp4'
-        }
+        pipeline_input = self.test_video
         self.pipeline_inference(
             body_3d_keypoints, pipeline_input=pipeline_input)
 
@@ -42,10 +39,7 @@ class Body3DKeypointsTest(unittest.TestCase, DemoCompatibilityCheck):
         if not cap.isOpened():
             raise Exception('modelscope error: %s cannot be decoded by OpenCV.'
                             % (self.test_video))
-        pipeline_input = {
-            'input_video': cap,
-            'output_video_path': './result.mp4'
-        }
+        pipeline_input = self.test_video
         self.pipeline_inference(
             body_3d_keypoints, pipeline_input=pipeline_input)
 
diff --git a/tests/pipelines/test_mplug_tasks.py b/tests/pipelines/test_mplug_tasks.py
index a3ace62d..11c9798f 100644
--- a/tests/pipelines/test_mplug_tasks.py
+++ b/tests/pipelines/test_mplug_tasks.py
@@ -26,7 +26,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             model=model,
         )
         image = Image.open('data/test/images/image_mplug_vqa.jpg')
-        result = pipeline_caption({'image': image})
+        result = pipeline_caption(image)
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -35,7 +35,7 @@ class MplugTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             Tasks.image_captioning,
             model='damo/mplug_image-captioning_coco_base_en')
         image = Image.open('data/test/images/image_mplug_vqa.jpg')
-        result = pipeline_caption({'image': image})
+        result = pipeline_caption(image)
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
diff --git a/tests/pipelines/test_ofa_tasks.py b/tests/pipelines/test_ofa_tasks.py
index e6638dfa..f8366508 100644
--- a/tests/pipelines/test_ofa_tasks.py
+++ b/tests/pipelines/test_ofa_tasks.py
@@ -34,7 +34,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             model=model,
         )
         image = 'data/test/images/image_captioning.png'
-        result = img_captioning({'image': image})
+        result = img_captioning(image)
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -42,8 +42,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         img_captioning = pipeline(
             Tasks.image_captioning,
             model='damo/ofa_image-caption_coco_large_en')
-        result = img_captioning(
-            {'image': 'data/test/images/image_captioning.png'})
+        result = img_captioning('data/test/images/image_captioning.png')
         print(result[OutputKeys.CAPTION])
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -52,8 +51,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             'damo/ofa_image-classification_imagenet_large_en')
         ofa_pipe = pipeline(Tasks.image_classification, model=model)
         image = 'data/test/images/image_classification.png'
-        input = {'image': image}
-        result = ofa_pipe(input)
+        result = ofa_pipe(image)
         print(result)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -62,8 +60,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             Tasks.image_classification,
             model='damo/ofa_image-classification_imagenet_large_en')
         image = 'data/test/images/image_classification.png'
-        input = {'image': image}
-        result = ofa_pipe(input)
+        result = ofa_pipe(image)
         print(result)
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
@@ -99,8 +96,8 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
         ofa_pipe = pipeline(Tasks.text_classification, model=model)
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'
-        input = {'text': text, 'text2': text2}
-        result = ofa_pipe(input)
+        result = ofa_pipe((text, text2))
+        result = ofa_pipe({'text': text, 'text2': text2})
         print(result)
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -110,8 +107,7 @@ class OfaTasksTest(unittest.TestCase, DemoCompatibilityCheck):
             model='damo/ofa_text-classification_mnli_large_en')
         text = 'One of our number will carry out your instructions minutely.'
         text2 = 'A member of my team will execute your orders with immense precision.'
-        input = {'text': text, 'text2': text2}
-        result = ofa_pipe(input)
+        result = ofa_pipe((text, text2))
        print(result)
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
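
Usage sketch (a minimal illustration of the input signatures introduced by this patch, mirroring the updated tests above; the model ids and image path are the ones used in those tests, while the example sentences and the video path are placeholders, and default model resolution for body_3d_keypoints is assumed):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # single-input tasks now take the raw input directly instead of a dict
    img_captioning = pipeline(
        Tasks.image_captioning, model='damo/ofa_image-caption_coco_large_en')
    result = img_captioning('data/test/images/image_captioning.png')

    # tasks that register several formats in TASK_INPUTS (e.g. text_classification)
    # accept either a tuple or a dict keyed as defined there
    ofa_pipe = pipeline(
        Tasks.text_classification,
        model='damo/ofa_text-classification_mnli_large_en')
    result = ofa_pipe(('first sentence', 'second sentence'))
    result = ofa_pipe({'text': 'first sentence', 'text2': 'second sentence'})

    # body_3d_keypoints takes the video directly; the output path moved from
    # the input dict to the 'output_video' keyword read in postprocess
    body_3d = pipeline(Tasks.body_3d_keypoints)  # default model assumed
    result = body_3d('path/to/video.mp4', output_video='./result.mp4')  # placeholder path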