
[to #43259593] refactor image preprocess

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9501913

    * [to #43259593] refactor image preprocess
Committed to master by yingda.chen, 3 years ago (commit 590bc52f97)
13 changed files with 56 additions and 116 deletions
  1. modelscope/models/multi_modal/clip/clip_model.py (+1 -1)
  2. modelscope/pipelines/audio/ans_pipeline.py (+0 -1)
  3. modelscope/pipelines/base.py (+0 -1)
  4. modelscope/pipelines/cv/animal_recog_pipeline.py (+2 -14)
  5. modelscope/pipelines/cv/image_cartoon_pipeline.py (+2 -13)
  6. modelscope/pipelines/cv/image_color_enhance_pipeline.py (+2 -13)
  7. modelscope/pipelines/cv/image_colorization_pipeline.py (+1 -1)
  8. modelscope/pipelines/cv/image_matting_pipeline.py (+2 -13)
  9. modelscope/pipelines/cv/image_super_resolution_pipeline.py (+2 -13)
 10. modelscope/pipelines/cv/ocr_detection_pipeline.py (+3 -13)
 11. modelscope/pipelines/cv/style_transfer_pipeline.py (+3 -27)
 12. modelscope/pipelines/cv/virtual_tryon_pipeline.py (+3 -5)
 13. modelscope/preprocessors/image.py (+35 -1)
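
Each CV pipeline touched here used to carry its own str / PIL.Image / np.ndarray branching inside preprocess(); the commit replaces that boilerplate with the two new LoadImage static helpers added in modelscope/preprocessors/image.py. A minimal sketch of the pattern the per-pipeline diffs below converge on (simplified; the real methods live on the pipeline classes and keep their existing transforms):

    from typing import Any, Dict

    import numpy as np

    from modelscope.preprocessors import LoadImage


    def preprocess(input) -> Dict[str, Any]:
        # The shared helper accepts a path/URL string, a PIL image or an
        # ndarray and returns an RGB ndarray, so the pipeline no longer
        # branches on the input type itself.
        img = LoadImage.convert_to_ndarray(input)
        img = img.astype(np.float32)
        return {'img': img}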

modelscope/models/multi_modal/clip/clip_model.py (+1 -1)

@@ -1,6 +1,6 @@
-import os.path as osp
 from typing import Any, Dict

+import cv2
 import json
 import numpy as np
 import torch


modelscope/pipelines/audio/ans_pipeline.py (+0 -1)

@@ -11,7 +11,6 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.constant import Tasks
-from modelscope.utils.torch_utils import create_device


 def audio_norm(x):


modelscope/pipelines/base.py (+0 -1)

@@ -8,7 +8,6 @@ from typing import Any, Dict, Generator, List, Mapping, Union


 import numpy as np

-from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.models.base import Model
 from modelscope.msdatasets import MsDataset
 from modelscope.outputs import TASK_OUTPUTS


modelscope/pipelines/cv/animal_recog_pipeline.py (+2 -14)

@@ -13,7 +13,7 @@ from modelscope.models.cv.animal_recognition import resnet
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage, load_image
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger


@@ -79,19 +79,7 @@ class AnimalRecogPipeline(Pipeline):
         logger.info('load model done')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = load_image(input)
-        elif isinstance(input, PIL.Image.Image):
-            img = input.convert('RGB')
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]
-            img = Image.fromarray(img.astype('uint8')).convert('RGB')
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
-
+        img = LoadImage.convert_to_img(input)
         normalize = transforms.Normalize(
             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
         test_transforms = transforms.Compose([


modelscope/pipelines/cv/image_cartoon_pipeline.py (+2 -13)

@@ -3,7 +3,6 @@ from typing import Any, Dict


 import cv2
 import numpy as np
-import PIL
 import tensorflow as tf

 from modelscope.metainfo import Pipelines
@@ -14,7 +13,7 @@ from modelscope.models.cv.cartoon.utils import get_f5p, padTo16x, resize_size
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import Tasks
 from modelscope.utils.logger import get_logger


@@ -65,17 +64,7 @@ class ImageCartoonPipeline(Pipeline):
         return sess

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_ndarray(input)
         img = img.astype(np.float)
         result = {'img': img}
         return result


modelscope/pipelines/cv/image_color_enhance_pipeline.py (+2 -13)

@@ -13,7 +13,7 @@ from modelscope.models.cv.image_color_enhance.image_color_enhance import \
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input
 from modelscope.preprocessors import (ImageColorEnhanceFinetunePreprocessor,
-                                      load_image)
+                                      LoadImage, load_image)
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
 from ..base import Pipeline
@@ -47,18 +47,7 @@ class ImageColorEnhancePipeline(Pipeline):
             self._device = torch.device('cpu')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = load_image(input)
-        elif isinstance(input, PIL.Image.Image):
-            img = input.convert('RGB')
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
-            img = Image.fromarray(img.astype('uint8')).convert('RGB')
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
-
+        img = LoadImage.convert_to_img(input)
         test_transforms = transforms.Compose([transforms.ToTensor()])
         img = test_transforms(img)
         result = {'src': img.unsqueeze(0).to(self._device)}


modelscope/pipelines/cv/image_colorization_pipeline.py (+1 -1)

@@ -88,7 +88,7 @@ class ImageColorizationPipeline(Pipeline):
             img = input.convert('LA').convert('RGB')
         elif isinstance(input, np.ndarray):
             if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
             img = input[:, :, ::-1]  # in rgb order
             img = PIL.Image.fromarray(img).convert('LA').convert('RGB')
         else:


modelscope/pipelines/cv/image_matting_pipeline.py (+2 -13)

@@ -3,13 +3,12 @@ from typing import Any, Dict


 import cv2
 import numpy as np
-import PIL

 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger


@@ -47,17 +46,7 @@ class ImageMattingPipeline(Pipeline):
         logger.info('load model done')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_img(input)
         img = img.astype(np.float)
         result = {'img': img}
         return result


modelscope/pipelines/cv/image_super_resolution_pipeline.py (+2 -13)

@@ -10,7 +10,7 @@ from modelscope.models.cv.super_resolution import rrdbnet_arch
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage, load_image
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger


@@ -46,18 +46,7 @@ class ImageSuperResolutionPipeline(Pipeline):
         logger.info('load model done')

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
-
+        img = LoadImage.convert_to_ndarray(input)
         img = torch.from_numpy(img).to(self.device).permute(
             2, 0, 1).unsqueeze(0) / 255.
         result = {'img': img}


modelscope/pipelines/cv/ocr_detection_pipeline.py (+3 -13)

@@ -3,14 +3,13 @@ from typing import Any, Dict


 import cv2
 import numpy as np
-import PIL
 import tensorflow as tf

 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger
 from .ocr_utils import model_resnet_mutex_v4_linewithchar, ops, utils
@@ -112,17 +111,8 @@ class OCRDetectionPipeline(Pipeline):
             model_loader.restore(sess, model_path)

     def preprocess(self, input: Input) -> Dict[str, Any]:
-        if isinstance(input, str):
-            img = np.array(load_image(input))
-        elif isinstance(input, PIL.Image.Image):
-            img = np.array(input.convert('RGB'))
-        elif isinstance(input, np.ndarray):
-            if len(input.shape) == 2:
-                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-            img = input[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(f'input should be either str, PIL.Image,'
-                            f' np.array, but got {type(input)}')
+        img = LoadImage.convert_to_ndarray(input)
+
         h, w, c = img.shape
         img_pad = np.zeros((max(h, w), max(h, w), 3), dtype=np.float32)
         img_pad[:h, :w, :] = img


modelscope/pipelines/cv/style_transfer_pipeline.py (+3 -27)

@@ -3,13 +3,12 @@ from typing import Any, Dict


 import cv2
 import numpy as np
-import PIL

 from modelscope.metainfo import Pipelines
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.preprocessors import load_image
+from modelscope.preprocessors import LoadImage
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger


@@ -61,35 +60,12 @@ class StyleTransferPipeline(Pipeline):
         return pipeline_parameters, {}, {}

     def preprocess(self, content: Input, style: Input) -> Dict[str, Any]:
-        if isinstance(content, str):
-            content = np.array(load_image(content))
-        elif isinstance(content, PIL.Image.Image):
-            content = np.array(content.convert('RGB'))
-        elif isinstance(content, np.ndarray):
-            if len(content.shape) == 2:
-                content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
-            content = content[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(
-                f'modelscope error: content should be either str, PIL.Image,'
-                f' np.array, but got {type(content)}')
+        content = LoadImage.convert_to_ndarray(content)
         if len(content.shape) == 2:
             content = cv2.cvtColor(content, cv2.COLOR_GRAY2BGR)
         content_img = content.astype(np.float)

-        if isinstance(style, str):
-            style_img = np.array(load_image(style))
-        elif isinstance(style, PIL.Image.Image):
-            style_img = np.array(style.convert('RGB'))
-        elif isinstance(style, np.ndarray):
-            if len(style.shape) == 2:
-                style_img = cv2.cvtColor(style, cv2.COLOR_GRAY2BGR)
-            style_img = style_img[:, :, ::-1]  # in rgb order
-        else:
-            raise TypeError(
-                f'modelscope error: style should be either str, PIL.Image,'
-                f' np.array, but got {type(style)}')
-
+        style_img = LoadImage.convert_to_ndarray(style)
         if len(style_img.shape) == 2:
             style_img = cv2.cvtColor(style_img, cv2.COLOR_GRAY2BGR)
         style_img = style_img.astype(np.float)


modelscope/pipelines/cv/virtual_tryon_pipeline.py (+3 -5)

@@ -1,21 +1,18 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.

 import os.path as osp
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Dict

 import cv2
 import numpy as np
 import PIL
 import torch
 from PIL import Image
-from torchvision import transforms

 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Pipelines
 from modelscope.models.cv.virual_tryon.sdafnet import SDAFNet_Tryon
-from modelscope.outputs import TASK_OUTPUTS, OutputKeys
-from modelscope.pipelines.util import is_model, is_official_hub_path
+from modelscope.outputs import OutputKeys
 from modelscope.preprocessors import load_image
 from modelscope.utils.constant import ModelFile, Tasks
 from ..base import Pipeline
@@ -67,6 +64,7 @@ class VirtualTryonPipeline(Pipeline):
         load_pretrained(self.model, src_params)
         self.model = self.model.eval()
         self.size = 192
+        from torchvision import transforms
         self.test_transforms = transforms.Compose([
             transforms.Resize(self.size, interpolation=2),
             transforms.ToTensor(),


modelscope/preprocessors/image.py (+35 -1)

@@ -2,7 +2,10 @@
 import io
 from typing import Any, Dict, Union

-import torch
+import cv2
+import numpy as np
+import PIL
+from numpy import ndarray
 from PIL import Image, ImageOps

 from modelscope.fileio import File
@@ -60,6 +63,37 @@ class LoadImage:
         repr_str = f'{self.__class__.__name__}(' f'mode={self.mode})'
         return repr_str

+    @staticmethod
+    def convert_to_ndarray(input) -> ndarray:
+        if isinstance(input, str):
+            img = np.array(load_image(input))
+        elif isinstance(input, PIL.Image.Image):
+            img = np.array(input.convert('RGB'))
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.array, but got {type(input)}')
+        return img
+
+    @staticmethod
+    def convert_to_img(input) -> ndarray:
+        if isinstance(input, str):
+            img = load_image(input)
+        elif isinstance(input, PIL.Image.Image):
+            img = input.convert('RGB')
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]
+            img = Image.fromarray(img.astype('uint8')).convert('RGB')
+        else:
+            raise TypeError(f'input should be either str, PIL.Image,'
+                            f' np.array, but got {type(input)}')
+        return img


 def load_image(image_path_or_url: str) -> Image.Image:
     """ simple interface to load an image from file or url

