Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9501913 * [to #43259593] refactor image preprocess
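The hunks below replace the copy-pasted str / PIL.Image / np.ndarray handling in each CV pipeline's preprocess() with two new LoadImage static helpers, convert_to_img and convert_to_ndarray. A minimal sketch of the resulting pipeline-side pattern (the pipeline class name and output key are illustrative, not taken from any file in this diff):

    from typing import Any, Dict

    import numpy as np

    from modelscope.pipelines.base import Input, Pipeline
    from modelscope.preprocessors import LoadImage

    class SomeImagePipeline(Pipeline):  # hypothetical pipeline, for illustration only

        def preprocess(self, input: Input) -> Dict[str, Any]:
            # One call accepts a file path / URL, a PIL.Image, or an np.ndarray
            # and yields an RGB-ordered ndarray.
            img = LoadImage.convert_to_ndarray(input)
            img = img.astype(np.float32)
            return {'img': img}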
@@ -1,6 +1,6 @@
 import os.path as osp
 from typing import Any, Dict

-import cv2
 import json
 import numpy as np
 import torch
@@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.constant import Tasks
-from modelscope.utils.torch_utils import create_device


 def audio_norm(x):
@@ -8,7 +8,6 @@ from typing import Any, Dict, Generator, List, Mapping, Union
 import numpy as np

-from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.base import Model
 from modelscope.msdatasets import MsDataset
 from modelscope.outputs import TASK_OUTPUTS
@@ -13,7 +13,7 @@ from modelscope.models.cv.animal_recognition import resnet
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage, load_image
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -79,19 +79,7 @@ class AnimalRecogPipeline(Pipeline):
         logger.info('load model done')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = load_image(input)
-        elif isinstance(input, PIL.Image.Image):
-            img = input.convert('RGB')
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]
-            img = Image.fromarray(img.astype('uint8')).convert('RGB')
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_img(input)
         normalize = transforms.Normalize(
             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         test_transforms = transforms.Compose([
@@ -3,7 +3,6 @@ from typing import Any, Dict
 import cv2
 import numpy as np
-import PIL
 import tensorflow as tf

 from modelscope.metainfo import Pipelines
@@ -14,7 +13,7 @@ from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger
@@ -65,17 +64,7 @@ class ImageCartoonPipeline(Pipeline):
         return sess

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_ndarray(input)
         img = img.astype(np.float)
         result = {'img': img}
         return result
@@ -13,7 +13,7 @@ from modelscope.models.cv.image_color_enhance.image_color_enhance import \
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input
 from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor,
-                                      load_image)
+                                      LoadImage, load_image)
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
 from ..base import Pipeline
@@ -47,18 +47,7 @@ class ImageColorEnhancePipeline(Pipeline):
         self._device = torch.device('cpu')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = load_image(input)
-        elif isinstance(input, PIL.Image.Image):
-            img = input.convert('RGB')
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
-            img = Image.fromarray(img.astype('uint8')).convert('RGB')
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_img(input)
         test_transforms = transforms.Compose([transforms.ToTensor()])
         img = test_transforms(img)
         result = {'src': img.unsqueeze(0).to(self._device)}
@@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline):
             img = input.convert('LA').convert('RGB')
         elif isinstance(input, np.ndarray):
             if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
             img = input[:, :, ::-1]  # in rgb order
             img = PIL.Image.fromarray(img).convert('LA').convert('RGB')
         else:
@@ -3,13 +3,12 @@ from typing import Any, Dict
 import cv2
 import numpy as np
-import PIL

 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
@@ -47,17 +46,7 @@ class ImageMattingPipeline(Pipeline):
         logger.info('load model done')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_ndarray(input)
         img = img.astype(np.float)
         result = {'img': img}
         return result
@@ -10,7 +10,7 @@ from modelscope.models.cv.super_resolution import rrdbnet_arch
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage, load_image
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
@@ -46,18 +46,7 @@ class ImageSuperResolutionPipeline(Pipeline):
         logger.info('load model done')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_ndarray(input)
         img = torch.from_numpy(img).to(self.device).permute(
             2, 0, 1).unsqueeze(0) / 255.
         result = {'img': img}
@@ -3,14 +3,13 @@ from typing import Any, Dict
 import cv2
 import numpy as np
-import PIL
 import tensorflow as tf

 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
 from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
@@ -112,17 +111,8 @@ class OCRDetectionPipeline(Pipeline):
         model_loader.restore(sess, model_path)

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_ndarray(input)
         h, w, c = img.shape
         img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32)
         img_pad[:h, :w, :] = img
@@ -3,13 +3,12 @@ from typing import Any, Dict
 import cv2
 import numpy as np
-import PIL

 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
@@ -61,35 +60,12 @@ class StyleTransferPipeline(Pipeline):
         return pipeline_parameters, {}, {}

     def preprocess(self, content: Input, style: Input) -> Dict[str, Any]:
-        if isinstance(content, str):
-            content = np.array(load_image(content))
-        elif isinstance(content, PIL.Image.Image):
-            content = np.array(content.convert('RGB'))
-        elif isinstance(content, np.ndarray):
-            if len(content.shape) == 2:
-                content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
-            content = content[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(
-                f'modelscope error: content should be either str, PIL.Image,'
-                f' np.array, but got {type(content)}')
+        content = LoadImage.convert_to_ndarray(content)
         if len(content.shape) == 2:
             content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
         content_img = content.astype(np.float)

-        if isinstance(style, str):
-            style_img = np.array(load_image(style))
-        elif isinstance(style, PIL.Image.Image):
-            style_img = np.array(style.convert('RGB'))
-        elif isinstance(style, np.ndarray):
-            if len(style.shape) == 2:
-                style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR)
-            style_img = style_img[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(
-                f'modelscope error: style should be either str, PIL.Image,'
-                f' np.array, but got {type(style)}')
+        style_img = LoadImage.convert_to_ndarray(style)
         if len(style_img.shape) == 2:
             style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR)
         style_img = style_img.astype(np.float)
@@ -1,21 +1,18 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os.path as osp
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Dict

 import cv2
 import numpy as np
 import PIL
 import torch
 from PIL import Image
-from torchvision import transforms

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Pipelines
 from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon
-from modelscope.outputs import TASK_OUTPUTS, OutputKeys
-from modelscope.pipelines.util import is_model, is_official_hub_path
+from modelscope.outputs import OutputKeys
 from modelscope.preprocessors import load_image
 from modelscope.utils.constant import ModelFile, Tasks
 from ..base import Pipeline
@@ -67,6 +64,7 @@ class VirtualTryonPipeline(Pipeline):
         load_pretrained(self.model, src_params)
         self.model = self.model.eval()
         self.size = 192
+        from torchvision import transforms
         self.test_transforms = transforms.Compose([
             transforms.Resize(self.size, interpolation=2),
             transforms.ToTensor(),
@@ -2,7 +2,10 @@
 import io
 from typing import Any, Dict, Union

-import torch
+import cv2
+import numpy as np
+import PIL
+from numpy import ndarray
 from PIL import Image, ImageOps

 from modelscope.fileio import File
@@ -60,6 +63,37 @@ class LoadImage:
         repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})'
         return repr_str

+    @staticmethod
+    def convert_to_ndarray(input) -> ndarray:
+        if isinstance(input, str):
+            img = np.array(load_image(input))
+        elif isinstance(input, PIL.Image.Image):
+            img = np.array(input.convert('RGB'))
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.array, but got {type(input)}')
+        return img
+
+    @staticmethod
+    def convert_to_img(input) -> ndarray:
+        if isinstance(input, str):
+            img = load_image(input)
+        elif isinstance(input, PIL.Image.Image):
+            img = input.convert('RGB')
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]
+            img = Image.fromarray(img.astype('uint8')).convert('RGB')
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.array, but got {type(input)}')
+        return img

 def load_image(image_path_or_url: str) -> Image.Image:
     """ simple interface to load an image from file or url