diff --git a/modelscope/pipelines/cv/virtual_tryon_pipeline.py b/modelscope/pipelines/cv/virtual_tryon_pipeline.py index afd5ad1a..b779e062 100644 --- a/modelscope/pipelines/cv/virtual_tryon_pipeline.py +++ b/modelscope/pipelines/cv/virtual_tryon_pipeline.py @@ -1,7 +1,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -from typing import Any, Dict +from typing import Any, Dict, Union import cv2 import numpy as np @@ -71,20 +71,31 @@ class VirtualTryonPipeline(Pipeline): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) - def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: - if isinstance(input['masked_model'], str): - img_agnostic = load_image(input['masked_model']) - pose = load_image(input['pose']) - cloth_img = load_image(input['cloth']) - elif isinstance(input['masked_model'], PIL.Image.Image): - img_agnostic = img_agnostic.convert('RGB') - pose = pose.convert('RGB') - cloth_img = cloth_img.convert('RGB') - elif isinstance(input['masked_model'], np.ndarray): + def preprocess(self, input: Union[Dict[str, Any], + tuple]) -> Dict[str, Any]: + if isinstance(input, tuple): + index_model = 0 + index_pose = 1 + index_cloth = 2 + else: + index_model = 'masked_model' + index_pose = 'pose' + index_cloth = 'cloth' + if isinstance(input[index_model], str): + img_agnostic = load_image(input[index_model]) + pose = load_image(input[index_pose]) + cloth_img = load_image(input[index_cloth]) + elif isinstance(input[index_model], PIL.Image.Image): + img_agnostic = input[index_model].convert('RGB') + pose = input[index_pose].convert('RGB') + cloth_img = input[index_cloth].convert('RGB') + elif isinstance(input[index_model], np.ndarray): if len(input.shape) == 2: - img_agnostic = cv2.cvtColor(img_agnostic, cv2.COLOR_GRAY2BGR) - pose = cv2.cvtColor(pose, cv2.COLOR_GRAY2BGR) - cloth_img = cv2.cvtColor(cloth_img, cv2.COLOR_GRAY2BGR) + img_agnostic = cv2.cvtColor(input[index_model], + cv2.COLOR_GRAY2BGR) + pose = cv2.cvtColor(input[index_pose], cv2.COLOR_GRAY2BGR) + cloth_img = cv2.cvtColor(input[index_cloth], + cv2.COLOR_GRAY2BGR) img_agnostic = Image.fromarray( img_agnostic[:, :, ::-1].astype('uint8')).convert('RGB') pose = Image.fromarray( diff --git a/tests/pipelines/test_virtual_tryon.py b/tests/pipelines/test_virtual_tryon.py index 324dc070..a81f27a9 100644 --- a/tests/pipelines/test_virtual_tryon.py +++ b/tests/pipelines/test_virtual_tryon.py @@ -3,6 +3,7 @@ import unittest import cv2 import numpy as np +from PIL import Image from modelscope.outputs import OutputKeys from modelscope.pipelines import pipeline @@ -12,11 +13,10 @@ from modelscope.utils.test_utils import test_level class VirtualTryonTest(unittest.TestCase): model_id = 'damo/cv_daflow_virtual-tryon_base' - input_imgs = { - 'masked_model': 'data/test/images/virtual_tryon_model.jpg', - 'pose': 'data/test/images/virtual_tryon_pose.jpg', - 'cloth': 'data/test/images/virtual_tryon_cloth.jpg' - } + masked_model = Image.open('data/test/images/virtual_tryon_model.jpg') + pose = Image.open('data/test/images/virtual_tryon_pose.jpg') + cloth = Image.open('data/test/images/virtual_tryon_cloth.jpg') + input_imgs = (masked_model, pose, cloth) @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') def test_run_with_model_name(self):