From f6e542cdcb6c1a1be690750bebda791ed5c90589 Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Mon, 17 Oct 2022 10:40:08 +0800 Subject: [PATCH] refine pipeline input to support demo service * image_captioning support single image and dict input * image_style_transfer use dict input Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10417330 --- modelscope/pipeline_inputs.py | 16 ++++++++++------ modelscope/pipelines/base.py | 6 +++++- tests/pipelines/test_image_style_transfer.py | 15 +++++++++------ 3 files changed, 24 insertions(+), 13 deletions(-) diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index 2b14c278..34b731c6 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -97,8 +97,10 @@ TASK_INPUTS = { InputType.IMAGE, Tasks.image_to_image_translation: InputType.IMAGE, - Tasks.image_style_transfer: - InputType.IMAGE, + Tasks.image_style_transfer: { + 'content': InputType.IMAGE, + 'style': InputType.IMAGE, + }, Tasks.image_portrait_stylization: InputType.IMAGE, Tasks.live_category: @@ -147,8 +149,9 @@ TASK_INPUTS = { InputType.TEXT, Tasks.translation: InputType.TEXT, - Tasks.word_segmentation: - InputType.TEXT, + Tasks.word_segmentation: [InputType.TEXT, { + 'text': InputType.TEXT, + }], Tasks.part_of_speech: InputType.TEXT, Tasks.named_entity_recognition: @@ -194,8 +197,9 @@ TASK_INPUTS = { InputType.AUDIO, # ============ multi-modal tasks =================== - Tasks.image_captioning: - InputType.IMAGE, + Tasks.image_captioning: [InputType.IMAGE, { + 'image': InputType.IMAGE, + }], Tasks.visual_grounding: { 'image': InputType.IMAGE, 'text': InputType.TEXT diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 5732a9d7..ea329be4 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -236,7 +236,11 @@ class Pipeline(ABC): if isinstance(input_type, list): matched_type = None for t in input_type: - if type(t) == type(input): + if isinstance(input, 
(dict, tuple)): + if type(t) == type(input): + matched_type = t + break + elif isinstance(t, str): matched_type = t break if matched_type is None: diff --git a/tests/pipelines/test_image_style_transfer.py b/tests/pipelines/test_image_style_transfer.py index a02d5308..5f37f204 100644 --- a/tests/pipelines/test_image_style_transfer.py +++ b/tests/pipelines/test_image_style_transfer.py @@ -25,8 +25,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_style_transfer, model=snapshot_path) result = image_style_transfer( - 'data/test/images/style_transfer_content.jpg', - style='data/test/images/style_transfer_style.jpg') + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) cv2.imwrite('result_styletransfer1.png', result[OutputKeys.OUTPUT_IMG]) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') @@ -35,8 +36,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck): Tasks.image_style_transfer, model=self.model_id) result = image_style_transfer( - 'data/test/images/style_transfer_content.jpg', - style='data/test/images/style_transfer_style.jpg') + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG]) print('style_transfer.test_run_modelhub done') @@ -45,8 +47,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck): image_style_transfer = pipeline(Tasks.image_style_transfer) result = image_style_transfer( - 'data/test/images/style_transfer_content.jpg', - style='data/test/images/style_transfer_style.jpg') + dict( + content='data/test/images/style_transfer_content.jpg', + style='data/test/images/style_transfer_style.jpg')) cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG]) print('style_transfer.test_run_modelhub_default_model done')