Browse Source

refine pipeline input to support demo service

* image_captioninig support single image and dict input
* image_style_transfer use dict input

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10417330
master
wenmeng.zwm 3 years ago
parent
commit
f6e542cdcb
3 changed files with 24 additions and 13 deletions
  1. +10
    -6
      modelscope/pipeline_inputs.py
  2. +5
    -1
      modelscope/pipelines/base.py
  3. +9
    -6
      tests/pipelines/test_image_style_transfer.py

+ 10
- 6
modelscope/pipeline_inputs.py View File

@@ -97,8 +97,10 @@ TASK_INPUTS = {
InputType.IMAGE,
Tasks.image_to_image_translation:
InputType.IMAGE,
Tasks.image_style_transfer:
InputType.IMAGE,
Tasks.image_style_transfer: {
'content': InputType.IMAGE,
'style': InputType.IMAGE,
},
Tasks.image_portrait_stylization:
InputType.IMAGE,
Tasks.live_category:
@@ -147,8 +149,9 @@ TASK_INPUTS = {
InputType.TEXT,
Tasks.translation:
InputType.TEXT,
Tasks.word_segmentation:
InputType.TEXT,
Tasks.word_segmentation: [InputType.TEXT, {
'text': InputType.TEXT,
}],
Tasks.part_of_speech:
InputType.TEXT,
Tasks.named_entity_recognition:
@@ -194,8 +197,9 @@ TASK_INPUTS = {
InputType.AUDIO,

# ============ multi-modal tasks ===================
Tasks.image_captioning:
InputType.IMAGE,
Tasks.image_captioning: [InputType.IMAGE, {
'image': InputType.IMAGE,
}],
Tasks.visual_grounding: {
'image': InputType.IMAGE,
'text': InputType.TEXT


+ 5
- 1
modelscope/pipelines/base.py View File

@@ -236,7 +236,11 @@ class Pipeline(ABC):
if isinstance(input_type, list):
matched_type = None
for t in input_type:
if type(t) == type(input):
if isinstance(input, (dict, tuple)):
if type(t) == type(input):
matched_type = t
break
elif isinstance(t, str):
matched_type = t
break
if matched_type is None:


+ 9
- 6
tests/pipelines/test_image_style_transfer.py View File

@@ -25,8 +25,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck):
Tasks.image_style_transfer, model=snapshot_path)

result = image_style_transfer(
'data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg')
dict(
content='data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg'))
cv2.imwrite('result_styletransfer1.png', result[OutputKeys.OUTPUT_IMG])

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@@ -35,8 +36,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck):
Tasks.image_style_transfer, model=self.model_id)

result = image_style_transfer(
'data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg')
dict(
content='data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg'))
cv2.imwrite('result_styletransfer2.png', result[OutputKeys.OUTPUT_IMG])
print('style_transfer.test_run_modelhub done')

@@ -45,8 +47,9 @@ class ImageStyleTransferTest(unittest.TestCase, DemoCompatibilityCheck):
image_style_transfer = pipeline(Tasks.image_style_transfer)

result = image_style_transfer(
'data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg')
dict(
content='data/test/images/style_transfer_content.jpg',
style='data/test/images/style_transfer_style.jpg'))
cv2.imwrite('result_styletransfer3.png', result[OutputKeys.OUTPUT_IMG])
print('style_transfer.test_run_modelhub_default_model done')



Loading…
Cancel
Save