|
|
|
@@ -1,7 +1,7 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
|
|
|
|
import os.path as osp |
|
|
|
from typing import Any, Dict |
|
|
|
from typing import Any, Dict, Union |
|
|
|
|
|
|
|
import cv2 |
|
|
|
import numpy as np |
|
|
|
@@ -71,20 +71,31 @@ class VirtualTryonPipeline(Pipeline): |
|
|
|
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) |
|
|
|
]) |
|
|
|
|
|
|
|
def preprocess(self, input: Dict[str, Any]) -> Dict[str, Any]: |
|
|
|
if isinstance(input['masked_model'], str): |
|
|
|
img_agnostic = load_image(input['masked_model']) |
|
|
|
pose = load_image(input['pose']) |
|
|
|
cloth_img = load_image(input['cloth']) |
|
|
|
elif isinstance(input['masked_model'], PIL.Image.Image): |
|
|
|
img_agnostic = img_agnostic.convert('RGB') |
|
|
|
pose = pose.convert('RGB') |
|
|
|
cloth_img = cloth_img.convert('RGB') |
|
|
|
elif isinstance(input['masked_model'], np.ndarray): |
|
|
|
def preprocess(self, input: Union[Dict[str, Any], |
|
|
|
tuple]) -> Dict[str, Any]: |
|
|
|
if isinstance(input, tuple): |
|
|
|
index_model = 0 |
|
|
|
index_pose = 1 |
|
|
|
index_cloth = 2 |
|
|
|
else: |
|
|
|
index_model = 'masked_model' |
|
|
|
index_pose = 'pose' |
|
|
|
index_cloth = 'cloth' |
|
|
|
if isinstance(input[index_model], str): |
|
|
|
img_agnostic = load_image(input[index_model]) |
|
|
|
pose = load_image(input[index_pose]) |
|
|
|
cloth_img = load_image(input[index_cloth]) |
|
|
|
elif isinstance(input[index_model], PIL.Image.Image): |
|
|
|
img_agnostic = input[index_model].convert('RGB') |
|
|
|
pose = input[index_pose].convert('RGB') |
|
|
|
cloth_img = input[index_cloth].convert('RGB') |
|
|
|
elif isinstance(input[index_model], np.ndarray): |
|
|
|
if len(input.shape) == 2: |
|
|
|
img_agnostic = cv2.cvtColor(img_agnostic, cv2.COLOR_GRAY2BGR) |
|
|
|
pose = cv2.cvtColor(pose, cv2.COLOR_GRAY2BGR) |
|
|
|
cloth_img = cv2.cvtColor(cloth_img, cv2.COLOR_GRAY2BGR) |
|
|
|
img_agnostic = cv2.cvtColor(input[index_model], |
|
|
|
cv2.COLOR_GRAY2BGR) |
|
|
|
pose = cv2.cvtColor(input[index_pose], cv2.COLOR_GRAY2BGR) |
|
|
|
cloth_img = cv2.cvtColor(input[index_cloth], |
|
|
|
cv2.COLOR_GRAY2BGR) |
|
|
|
img_agnostic = Image.fromarray( |
|
|
|
img_agnostic[:, :, ::-1].astype('uint8')).convert('RGB') |
|
|
|
pose = Image.fromarray( |
|
|
|
|