Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9501913 * [to #43259593] refactor image preprocess master
| @@ -1,6 +1,6 @@ | |||
| import os.path as osp | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| @@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.torch_utils import create_device | |||
| def audio_norm(x): | |||
| @@ -8,7 +8,6 @@ from typing import Any, Dict, Generator, List, Mapping, Union | |||
| import numpy as np | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.models.base import Model | |||
| from modelscope.msdatasets import MsDataset | |||
| from modelscope.outputs import TASK_OUTPUTS | |||
| @@ -13,7 +13,7 @@ from modelscope.models.cv.animal_recognition import resnet | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.preprocessors import LoadImage, load_image | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -79,19 +79,7 @@ class AnimalRecogPipeline(Pipeline): | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| img = load_image(input) | |||
| elif isinstance(input, PIL.Image.Image): | |||
| img = input.convert('RGB') | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
| img = input[:, :, ::-1] | |||
| img = Image.fromarray(img.astype('uint8')).convert('RGB') | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.array, but got {type(input)}') | |||
| img = LoadImage.convert_to_img(input) | |||
| normalize = transforms.Normalize( | |||
| mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |||
| test_transforms = transforms.Compose([ | |||
| @@ -3,7 +3,6 @@ from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import tensorflow as tf | |||
| from modelscope.metainfo import Pipelines | |||
| @@ -14,7 +13,7 @@ from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -65,17 +64,7 @@ class ImageCartoonPipeline(Pipeline): | |||
| return sess | |||
def preprocess(self, input: Input) -> Dict[str, Any]:
    """Turn the pipeline input into a float array for the cartoon model.

    Args:
        input: Anything accepted by ``LoadImage.convert_to_ndarray``
            (a str, a ``PIL.Image.Image`` or an ``np.ndarray``).

    Returns:
        A dict with the single key ``'img'`` holding the converted array
        cast to float64.

    Raises:
        TypeError: Propagated from ``LoadImage.convert_to_ndarray`` when
            ``input`` is of an unsupported type.
    """
    # Type dispatch lives in the shared helper; the inline str/PIL/ndarray
    # chain that used to precede this call only duplicated it.
    img = LoadImage.convert_to_ndarray(input)
    # ``np.float`` was just an alias of the builtin ``float`` and has been
    # removed from NumPy; ``np.float64`` is the same dtype.
    img = img.astype(np.float64)
    return {'img': img}
| @@ -13,7 +13,7 @@ from modelscope.models.cv.image_color_enhance.image_color_enhance import \ | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, | |||
| load_image) | |||
| LoadImage, load_image) | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| @@ -47,18 +47,7 @@ class ImageColorEnhancePipeline(Pipeline): | |||
| self._device = torch.device('cpu') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| img = load_image(input) | |||
| elif isinstance(input, PIL.Image.Image): | |||
| img = input.convert('RGB') | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) | |||
| img = Image.fromarray(img.astype('uint8')).convert('RGB') | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.array, but got {type(input)}') | |||
| img = LoadImage.convert_to_img(input) | |||
| test_transforms = transforms.Compose([transforms.ToTensor()]) | |||
| img = test_transforms(img) | |||
| result = {'src': img.unsqueeze(0).to(self._device)} | |||
| @@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline): | |||
| img = input.convert('LA').convert('RGB') | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
| img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) | |||
| img = input[:, :, ::-1] # in rgb order | |||
| img = PIL.Image.fromarray(img).convert('LA').convert('RGB') | |||
| else: | |||
| @@ -3,13 +3,12 @@ from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -47,17 +46,7 @@ class ImageMattingPipeline(Pipeline): | |||
| logger.info('load model done') | |||
def preprocess(self, input: Input) -> Dict[str, Any]:
    """Turn the pipeline input into a float array for the matting model.

    Args:
        input: Anything accepted by ``LoadImage.convert_to_ndarray``
            (a str, a ``PIL.Image.Image`` or an ``np.ndarray``).

    Returns:
        A dict with the single key ``'img'`` holding the converted array
        cast to float64.

    Raises:
        TypeError: Propagated from ``LoadImage.convert_to_ndarray`` when
            ``input`` is of an unsupported type.
    """
    # An ndarray is required here because ``.astype`` is called next;
    # ``convert_to_img`` would hand back a PIL image, which has no
    # ``astype`` method.
    img = LoadImage.convert_to_ndarray(input)
    # ``np.float`` was just an alias of the builtin ``float`` and has been
    # removed from NumPy; ``np.float64`` is the same dtype.
    img = img.astype(np.float64)
    return {'img': img}
| @@ -10,7 +10,7 @@ from modelscope.models.cv.super_resolution import rrdbnet_arch | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.preprocessors import LoadImage, load_image | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -46,18 +46,7 @@ class ImageSuperResolutionPipeline(Pipeline): | |||
| logger.info('load model done') | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| img = np.array(load_image(input)) | |||
| elif isinstance(input, PIL.Image.Image): | |||
| img = np.array(input.convert('RGB')) | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
| img = input[:, :, ::-1] # in rgb order | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.array, but got {type(input)}') | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| img = torch.from_numpy(img).to(self.device).permute( | |||
| 2, 0, 1).unsqueeze(0) / 255. | |||
| result = {'img': img} | |||
| @@ -3,14 +3,13 @@ from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import tensorflow as tf | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils | |||
| @@ -112,17 +111,8 @@ class OCRDetectionPipeline(Pipeline): | |||
| model_loader.restore(sess, model_path) | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| if isinstance(input, str): | |||
| img = np.array(load_image(input)) | |||
| elif isinstance(input, PIL.Image.Image): | |||
| img = np.array(input.convert('RGB')) | |||
| elif isinstance(input, np.ndarray): | |||
| if len(input.shape) == 2: | |||
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |||
| img = input[:, :, ::-1] # in rgb order | |||
| else: | |||
| raise TypeError(f'input should be either str, PIL.Image,' | |||
| f' np.array, but got {type(input)}') | |||
| img = LoadImage.convert_to_ndarray(input) | |||
| h, w, c = img.shape | |||
| img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32) | |||
| img_pad[:h, :w, :] = img | |||
| @@ -3,13 +3,12 @@ from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input, Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.preprocessors import LoadImage | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| @@ -61,35 +60,12 @@ class StyleTransferPipeline(Pipeline): | |||
| return pipeline_parameters, {}, {} | |||
| def preprocess(self, content: Input, style: Input) -> Dict[str, Any]: | |||
| if isinstance(content, str): | |||
| content = np.array(load_image(content)) | |||
| elif isinstance(content, PIL.Image.Image): | |||
| content = np.array(content.convert('RGB')) | |||
| elif isinstance(content, np.ndarray): | |||
| if len(content.shape) == 2: | |||
| content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) | |||
| content = content[:, :, ::-1] # in rgb order | |||
| else: | |||
| raise TypeError( | |||
| f'modelscope error: content should be either str, PIL.Image,' | |||
| f' np.array, but got {type(content)}') | |||
| content = LoadImage.convert_to_ndarray(content) | |||
| if len(content.shape) == 2: | |||
| content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) | |||
| content_img = content.astype(np.float) | |||
| if isinstance(style, str): | |||
| style_img = np.array(load_image(style)) | |||
| elif isinstance(style, PIL.Image.Image): | |||
| style_img = np.array(style.convert('RGB')) | |||
| elif isinstance(style, np.ndarray): | |||
| if len(style.shape) == 2: | |||
| style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR) | |||
| style_img = style_img[:, :, ::-1] # in rgb order | |||
| else: | |||
| raise TypeError( | |||
| f'modelscope error: style should be either str, PIL.Image,' | |||
| f' np.array, but got {type(style)}') | |||
| style_img = LoadImage.convert_to_ndarray(style) | |||
| if len(style_img.shape) == 2: | |||
| style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR) | |||
| style_img = style_img.astype(np.float) | |||
| @@ -1,21 +1,18 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os.path as osp | |||
| from abc import ABC, abstractmethod | |||
| from typing import Any, Dict, Generator, List, Union | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from PIL import Image | |||
| from torchvision import transforms | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon | |||
| from modelscope.outputs import TASK_OUTPUTS, OutputKeys | |||
| from modelscope.pipelines.util import is_model, is_official_hub_path | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from ..base import Pipeline | |||
| @@ -67,6 +64,7 @@ class VirtualTryonPipeline(Pipeline): | |||
| load_pretrained(self.model, src_params) | |||
| self.model = self.model.eval() | |||
| self.size = 192 | |||
| from torchvision import transforms | |||
| self.test_transforms = transforms.Compose([ | |||
| transforms.Resize(self.size, interpolation=2), | |||
| transforms.ToTensor(), | |||
| @@ -2,7 +2,10 @@ | |||
| import io | |||
| from typing import Any, Dict, Union | |||
| import torch | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| from numpy import ndarray | |||
| from PIL import Image, ImageOps | |||
| from modelscope.fileio import File | |||
| @@ -60,6 +63,37 @@ class LoadImage: | |||
| repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})' | |||
| return repr_str | |||
@staticmethod
def convert_to_ndarray(input) -> ndarray:
    """Normalise a str / PIL image / ndarray input into an ndarray.

    Args:
        input: A str handed to ``load_image``, a ``PIL.Image.Image``
            (converted to mode ``'RGB'``), or an ``np.ndarray``.

    Returns:
        The image as an ndarray. For ndarray inputs the result is the
        input with its last axis reversed; 2-D inputs are first expanded
        with ``cv2.COLOR_GRAY2BGR``.

    Raises:
        TypeError: If ``input`` is none of the supported types.
    """
    if isinstance(input, str):
        return np.array(load_image(input))
    if isinstance(input, PIL.Image.Image):
        return np.array(input.convert('RGB'))
    if not isinstance(input, np.ndarray):
        raise TypeError(f'input should be either str, PIL.Image,'
                        f' np.array, but got {type(input)}')

    pixels = input
    if pixels.ndim == 2:
        # Single-channel array: expand to three channels before flipping.
        pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2BGR)
    # Reverse channel order; presumably ndarray inputs arrive in OpenCV's
    # BGR order so this yields RGB -- confirm against callers.
    return pixels[:, :, ::-1]
@staticmethod
def convert_to_img(input) -> Image.Image:
    """Normalise a str / PIL image / ndarray input into a PIL image.

    Args:
        input: A str handed to ``load_image``, a ``PIL.Image.Image``, or
            an ``np.ndarray`` (2-D arrays are expanded to three channels
            with ``cv2.COLOR_GRAY2BGR``).

    Returns:
        A PIL image. PIL and ndarray inputs are converted to mode
        ``'RGB'``; str inputs are returned as produced by ``load_image``.

    Raises:
        TypeError: If ``input`` is none of the supported types.
    """
    if isinstance(input, str):
        return load_image(input)
    if isinstance(input, PIL.Image.Image):
        return input.convert('RGB')
    if isinstance(input, np.ndarray):
        if len(input.shape) == 2:
            # Rebind ``input`` so the expanded array is what gets sliced
            # below; slicing the original 2-D array with three indices
            # raises IndexError.
            input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
        # Reverse channel order; presumably ndarray inputs arrive in
        # OpenCV's BGR order so this yields RGB -- confirm against callers.
        img = input[:, :, ::-1]
        return Image.fromarray(img.astype('uint8')).convert('RGB')
    raise TypeError(f'input should be either str, PIL.Image,'
                    f' np.array, but got {type(input)}')
| def load_image(image_path_or_url: str) -> Image.Image: | |||
| """ simple interface to load an image from file or url | |||