From 590bc52f97d1806db31d4227ff445df6930de798 Mon Sep 17 00:00:00 2001 From: "yingda.chen" Date: Mon, 25 Jul 2022 17:00:21 +0800 Subject: [PATCH] [to #43259593] refacor image preprocess Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9501913 * [to #43259593] refacor image preprocess --- .../models/multi_modal/clip/clip_model.py | 2 +- modelscope/pipelines/audio/ans_pipeline.py | 1 - modelscope/pipelines/base.py | 1 - .../pipelines/cv/animal_recog_pipeline.py | 16 ++------- .../pipelines/cv/image_cartoon_pipeline.py | 15 ++------ .../cv/image_color_enhance_pipeline.py | 15 ++------ .../cv/image_colorization_pipeline.py | 2 +- .../pipelines/cv/image_matting_pipeline.py | 15 ++------ .../cv/image_super_resolution_pipeline.py | 15 ++------ .../pipelines/cv/ocr_detection_pipeline.py | 16 ++------- .../pipelines/cv/style_transfer_pipeline.py | 30 ++-------------- .../pipelines/cv/virtual_tryon_pipeline.py | 8 ++--- modelscope/preprocessors/image.py | 36 ++++++++++++++++++- 13 files changed, 56 insertions(+), 116 deletions(-) diff --git a/modelscope/models/multi_modal/clip/clip_model.py b/modelscope/models/multi_modal/clip/clip_model.py index 79f7ae44..8dd36acf 100644 --- a/modelscope/models/multi_modal/clip/clip_model.py +++ b/modelscope/models/multi_modal/clip/clip_model.py @@ -1,6 +1,6 @@ -import os.path as osp from typing import Any, Dict +import cv2 import json import numpy as np import torch diff --git a/modelscope/pipelines/audio/ans_pipeline.py b/modelscope/pipelines/audio/ans_pipeline.py index 298f8bd8..80b6bae1 100644 --- a/modelscope/pipelines/audio/ans_pipeline.py +++ b/modelscope/pipelines/audio/ans_pipeline.py @@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.utils.constant import Tasks -from modelscope.utils.torch_utils import create_device def audio_norm(x): diff --git a/modelscope/pipelines/base.py b/modelscope/pipelines/base.py index 8a2c13bc..7fda7018 100644 --- a/modelscope/pipelines/base.py +++ b/modelscope/pipelines/base.py @@ -8,7 +8,6 @@ from typing import Any, Dict, Generator, List, Mapping, Union import numpy as np -from modelscope.hub.snapshot_download import snapshot_download from modelscope.models.base import Model from modelscope.msdatasets import MsDataset from modelscope.outputs import TASK_OUTPUTS diff --git a/modelscope/pipelines/cv/animal_recog_pipeline.py b/modelscope/pipelines/cv/animal_recog_pipeline.py index 5cb752b5..3260ea6e 100644 --- a/modelscope/pipelines/cv/animal_recog_pipeline.py +++ b/modelscope/pipelines/cv/animal_recog_pipeline.py @@ -13,7 +13,7 @@ from modelscope.models.cv.animal_recognition import resnet from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import load_image +from modelscope.preprocessors import LoadImage, load_image from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -79,19 +79,7 @@ class AnimalRecogPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - img = load_image(input) - elif isinstance(input, PIL.Image.Image): - img = input.convert('RGB') - elif isinstance(input, np.ndarray): - if len(input.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - img = input[:, :, ::-1] - img = Image.fromarray(img.astype('uint8')).convert('RGB') - else: - raise TypeError(f'input should be either str, PIL.Image,' - f' np.array, but got {type(input)}') - + img = LoadImage.convert_to_img(input) normalize = transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) test_transforms = transforms.Compose([ diff --git a/modelscope/pipelines/cv/image_cartoon_pipeline.py b/modelscope/pipelines/cv/image_cartoon_pipeline.py index 2ea19d70..d351020f 100644 --- a/modelscope/pipelines/cv/image_cartoon_pipeline.py +++ b/modelscope/pipelines/cv/image_cartoon_pipeline.py @@ -3,7 +3,6 @@ from typing import Any, Dict import cv2 import numpy as np -import PIL import tensorflow as tf from modelscope.metainfo import Pipelines @@ -14,7 +13,7 @@ from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import load_image +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger @@ -65,17 +64,7 @@ class ImageCartoonPipeline(Pipeline): return sess def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - img = np.array(load_image(input)) - elif isinstance(input, PIL.Image.Image): - img = np.array(input.convert('RGB')) - elif isinstance(input, np.ndarray): - if len(input.shape) == 2: - input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) - img = input[:, :, ::-1] - else: - raise TypeError(f'input should be either str, PIL.Image,' - f' np.array, but got {type(input)}') + img = LoadImage.convert_to_ndarray(input) img = img.astype(np.float) result = {'img': img} return result diff --git a/modelscope/pipelines/cv/image_color_enhance_pipeline.py b/modelscope/pipelines/cv/image_color_enhance_pipeline.py index 506488f3..c6de89a4 100644 --- a/modelscope/pipelines/cv/image_color_enhance_pipeline.py +++ b/modelscope/pipelines/cv/image_color_enhance_pipeline.py @@ -13,7 +13,7 @@ from modelscope.models.cv.image_color_enhance.image_color_enhance import \ from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor, - load_image) + LoadImage, load_image) from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from ..base import Pipeline @@ -47,18 +47,7 @@ class ImageColorEnhancePipeline(Pipeline): self._device = torch.device('cpu') def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - img = load_image(input) - elif isinstance(input, PIL.Image.Image): - img = input.convert('RGB') - elif isinstance(input, np.ndarray): - if len(input.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) - img = Image.fromarray(img.astype('uint8')).convert('RGB') - else: - raise TypeError(f'input should be either str, PIL.Image,' - f' np.array, but got {type(input)}') - + img = LoadImage.convert_to_img(input) test_transforms = transforms.Compose([transforms.ToTensor()]) img = test_transforms(img) result = {'src': img.unsqueeze(0).to(self._device)} diff --git a/modelscope/pipelines/cv/image_colorization_pipeline.py b/modelscope/pipelines/cv/image_colorization_pipeline.py index 0e5cc3e1..8992ba8e 100644 --- a/modelscope/pipelines/cv/image_colorization_pipeline.py +++ b/modelscope/pipelines/cv/image_colorization_pipeline.py @@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline): img = input.convert('LA').convert('RGB') elif isinstance(input, np.ndarray): if len(input.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) img = input[:, :, ::-1] # in rgb order img = PIL.Image.fromarray(img).convert('LA').convert('RGB') else: diff --git a/modelscope/pipelines/cv/image_matting_pipeline.py b/modelscope/pipelines/cv/image_matting_pipeline.py index 91d3c515..691b1d94 100644 --- a/modelscope/pipelines/cv/image_matting_pipeline.py +++ b/modelscope/pipelines/cv/image_matting_pipeline.py @@ -3,13 +3,12 @@ from typing import Any, Dict import cv2 import numpy as np -import PIL from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import load_image +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -47,17 +46,7 @@ class ImageMattingPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - img = np.array(load_image(input)) - elif isinstance(input, PIL.Image.Image): - img = np.array(input.convert('RGB')) - elif isinstance(input, np.ndarray): - if len(input.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - img = input[:, :, ::-1] # in rgb order - else: - raise TypeError(f'input should be either str, PIL.Image,' - f' np.array, but got {type(input)}') + img = LoadImage.convert_to_img(input) img = img.astype(np.float) result = {'img': img} return result diff --git a/modelscope/pipelines/cv/image_super_resolution_pipeline.py b/modelscope/pipelines/cv/image_super_resolution_pipeline.py index 9f1dab62..86e4042f 100644 --- a/modelscope/pipelines/cv/image_super_resolution_pipeline.py +++ b/modelscope/pipelines/cv/image_super_resolution_pipeline.py @@ -10,7 +10,7 @@ from modelscope.models.cv.super_resolution import rrdbnet_arch from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import load_image +from modelscope.preprocessors import LoadImage, load_image from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -46,18 +46,7 @@ class ImageSuperResolutionPipeline(Pipeline): logger.info('load model done') def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - img = np.array(load_image(input)) - elif isinstance(input, PIL.Image.Image): - img = np.array(input.convert('RGB')) - elif isinstance(input, np.ndarray): - if len(input.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - img = input[:, :, ::-1] # in rgb order - else: - raise TypeError(f'input should be either str, PIL.Image,' - f' np.array, but got {type(input)}') - + img = LoadImage.convert_to_ndarray(input) img = torch.from_numpy(img).to(self.device).permute( 2, 0, 1).unsqueeze(0) / 255. result = {'img': img} diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index 4d842fbe..4ebfbc65 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -3,14 +3,13 @@ from typing import Any, Dict import cv2 import numpy as np -import PIL import tensorflow as tf from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import load_image +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils @@ -112,17 +111,8 @@ class OCRDetectionPipeline(Pipeline): model_loader.restore(sess, model_path) def preprocess(self, input: Input) -> Dict[str, Any]: - if isinstance(input, str): - img = np.array(load_image(input)) - elif isinstance(input, PIL.Image.Image): - img = np.array(input.convert('RGB')) - elif isinstance(input, np.ndarray): - if len(input.shape) == 2: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - img = input[:, :, ::-1] # in rgb order - else: - raise TypeError(f'input should be either str, PIL.Image,' - f' np.array, but got {type(input)}') + img = LoadImage.convert_to_ndarray(input) + h, w, c = img.shape img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32) img_pad[:h, :w, :] = img diff --git a/modelscope/pipelines/cv/style_transfer_pipeline.py b/modelscope/pipelines/cv/style_transfer_pipeline.py index cb7ede3b..687f0d40 100644 --- a/modelscope/pipelines/cv/style_transfer_pipeline.py +++ b/modelscope/pipelines/cv/style_transfer_pipeline.py @@ -3,13 +3,12 @@ from typing import Any, Dict import cv2 import numpy as np -import PIL from modelscope.metainfo import Pipelines from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.preprocessors import load_image +from modelscope.preprocessors import LoadImage from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.logger import get_logger @@ -61,35 +60,12 @@ class StyleTransferPipeline(Pipeline): return pipeline_parameters, {}, {} def preprocess(self, content: Input, style: Input) -> Dict[str, Any]: - if isinstance(content, str): - content = np.array(load_image(content)) - elif isinstance(content, PIL.Image.Image): - content = np.array(content.convert('RGB')) - elif isinstance(content, np.ndarray): - if len(content.shape) == 2: - content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) - content = content[:, :, ::-1] # in rgb order - else: - raise TypeError( - f'modelscope error: content should be either str, PIL.Image,' - f' np.array, but got {type(content)}') + content = LoadImage.convert_to_ndarray(content) if len(content.shape) == 2: content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR) content_img = content.astype(np.float) - if isinstance(style, str): - style_img = np.array(load_image(style)) - elif isinstance(style, PIL.Image.Image): - style_img = np.array(style.convert('RGB')) - elif isinstance(style, np.ndarray): - if len(style.shape) == 2: - style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR) - style_img = style_img[:, :, ::-1] # in rgb order - else: - raise TypeError( - f'modelscope error: style should be either str, PIL.Image,' - f' np.array, but got {type(style)}') - + style_img = LoadImage.convert_to_ndarray(style) if len(style_img.shape) == 2: style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR) style_img = style_img.astype(np.float) diff --git a/modelscope/pipelines/cv/virtual_tryon_pipeline.py b/modelscope/pipelines/cv/virtual_tryon_pipeline.py index 5d849ba2..c6577c35 100644 --- a/modelscope/pipelines/cv/virtual_tryon_pipeline.py +++ b/modelscope/pipelines/cv/virtual_tryon_pipeline.py @@ -1,21 +1,18 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp -from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Union +from typing import Any, Dict import cv2 import numpy as np import PIL import torch from PIL import Image -from torchvision import transforms from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Pipelines from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon -from modelscope.outputs import TASK_OUTPUTS, OutputKeys -from modelscope.pipelines.util import is_model, is_official_hub_path +from modelscope.outputs import OutputKeys from modelscope.preprocessors import load_image from modelscope.utils.constant import ModelFile, Tasks from ..base import Pipeline @@ -67,6 +64,7 @@ class VirtualTryonPipeline(Pipeline): load_pretrained(self.model, src_params) self.model = self.model.eval() self.size = 192 + from torchvision import transforms self.test_transforms = transforms.Compose([ transforms.Resize(self.size, interpolation=2), transforms.ToTensor(), diff --git a/modelscope/preprocessors/image.py b/modelscope/preprocessors/image.py index 85afb5b8..4c911f97 100644 --- a/modelscope/preprocessors/image.py +++ b/modelscope/preprocessors/image.py @@ -2,7 +2,10 @@ import io from typing import Any, Dict, Union -import torch +import cv2 +import numpy as np +import PIL +from numpy import ndarray from PIL import Image, ImageOps from modelscope.fileio import File @@ -60,6 +63,37 @@ class LoadImage: repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})' return repr_str + @staticmethod + def convert_to_ndarray(input) -> ndarray: + if isinstance(input, str): + img = np.array(load_image(input)) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + return img + + @staticmethod + def convert_to_img(input) -> ndarray: + if isinstance(input, str): + img = load_image(input) + elif isinstance(input, PIL.Image.Image): + img = input.convert('RGB') + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] + img = Image.fromarray(img.astype('uint8')).convert('RGB') + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + return img + def load_image(image_path_or_url: str) -> Image.Image: """ simple interface to load an image from file or url