baiguan.yt committed 3 years ago
parent commit 14a62d401e
10 changed files with 831 additions and 0 deletions
  1. data/test/images/marilyn_monroe_4.jpg (+3, -0)
  2. modelscope/metainfo.py (+1, -0)
  3. modelscope/models/cv/image_colorization/unet.py (+300, -0)
  4. modelscope/models/cv/image_colorization/utils.py (+348, -0)
  5. modelscope/outputs.py (+1, -0)
  6. modelscope/pipelines/builder.py (+2, -0)
  7. modelscope/pipelines/cv/__init__.py (+1, -0)
  8. modelscope/pipelines/cv/image_colorization_pipeline.py (+132, -0)
  9. modelscope/utils/constant.py (+1, -0)
  10. tests/pipelines/test_image_colorization.py (+42, -0)

data/test/images/marilyn_monroe_4.jpg (+3, -0)

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b425fb89442e4c6c32c71c17c1c1afef8a2c5bc9ec9529b5a0fc21c53e1a02b
size 39248

modelscope/metainfo.py (+1, -0)

@@ -50,6 +50,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    image_colorization = 'unet-image-colorization'
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'


modelscope/models/cv/image_colorization/unet.py (+300, -0)

@@ -0,0 +1,300 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

from .utils import (MergeLayer, NormType, PixelShuffle_ICNR, SelfAttention,
SequentialEx, SigmoidRange, dummy_eval, hook_outputs,
in_channels, model_sizes, relu, res_block)

__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']


def custom_conv_layer(
ni,
nf,
ks=3,
stride=1,
padding=None,
bias=None,
is_1d=False,
norm_type=NormType.Batch,
use_activ=True,
leaky=None,
transpose=False,
init=nn.init.kaiming_normal_,
self_attention=False,
extra_bn=False,
):
'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.'
if padding is None:
padding = (ks - 1) // 2 if not transpose else 0
bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn is True
if bias is None:
bias = not bn
    # Transposed convs take precedence over the 1d/2d choice.
    conv_func = nn.ConvTranspose2d if transpose else (
        nn.Conv1d if is_1d else nn.Conv2d)
conv = conv_func(
ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)
if norm_type == NormType.Weight:
conv = weight_norm(conv)
elif norm_type == NormType.Spectral:
conv = spectral_norm(conv)

layers = [conv]
if use_activ:
layers.append(relu(True, leaky=leaky))
if bn:
layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
if self_attention:
layers.append(SelfAttention(nf))
return nn.Sequential(*layers)


def _get_sfs_idxs(sizes):
'Get the indexes of the layers where the size of the activation changes.'
feature_szs = [size[-1] for size in sizes]
sfs_idxs = list(
np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
if feature_szs[0] != feature_szs[1]:
sfs_idxs = [0] + sfs_idxs
return sfs_idxs


class CustomPixelShuffle_ICNR(nn.Module):
'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.'

def __init__(self, ni, nf=None, scale=2, blur=False, leaky=None, **kwargs):
super().__init__()
nf = ni if nf is None else nf
self.conv = custom_conv_layer(
ni, nf * (scale**2), ks=1, use_activ=False, **kwargs)
self.shuf = nn.PixelShuffle(scale)
# Blurring over (h*w) kernel
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
# - https://arxiv.org/abs/1806.02658
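    # NOTE: `self.blur` is assigned the AvgPool2d module below, which is
    # always truthy, so the blur branch in `forward` runs regardless of
    # the `blur` flag.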
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
self.blur = nn.AvgPool2d(2, stride=1)
self.relu = relu(True, leaky=leaky)

def forward(self, x):
x = self.shuf(self.relu(self.conv(x)))
return self.blur(self.pad(x)) if self.blur else x


class UnetBlockDeep(nn.Module):
    'A quasi-UNet block, using `PixelShuffle_ICNR` upsampling.'

def __init__(self,
up_in_c,
x_in_c,
hook,
final_div=True,
blur=False,
leaky=None,
self_attention=False,
nf_factor=1.0,
**kwargs):
super().__init__()
self.hook = hook
self.shuf = CustomPixelShuffle_ICNR(
up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs)
self.bn = nn.BatchNorm2d(x_in_c)
ni = up_in_c // 2 + x_in_c
nf = int((ni if final_div else ni // 2) * nf_factor)
self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
self.conv2 = custom_conv_layer(
nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
self.relu = relu(leaky=leaky)

def forward(self, up_in):
s = self.hook.stored
up_out = self.shuf(up_in)
ssh = s.shape[-2:]
if ssh != up_out.shape[-2:]:
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
return self.conv2(self.conv1(cat_x))


class DynamicUnetDeep(SequentialEx):
'Create a U-Net from a given architecture.'

def __init__(self,
encoder,
n_classes,
blur=False,
blur_final=True,
self_attention=False,
y_range=None,
last_cross=True,
bottle=False,
norm_type=NormType.Batch,
nf_factor=1.0,
**kwargs):
extra_bn = norm_type == NormType.Spectral
imsize = (256, 256)
sfs_szs = model_sizes(encoder, size=imsize)
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
x = dummy_eval(encoder, imsize).detach()

ni = sfs_szs[-1][1]
middle_conv = nn.Sequential(
custom_conv_layer(
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
custom_conv_layer(
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
).eval()
x = middle_conv(x)
layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv]

for i, idx in enumerate(sfs_idxs):
not_final = i != len(sfs_idxs) - 1
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
sa = self_attention and (i == len(sfs_idxs) - 3)
unet_block = UnetBlockDeep(
up_in_c,
x_in_c,
self.sfs[i],
final_div=not_final,
blur=blur,
self_attention=sa,
norm_type=norm_type,
extra_bn=extra_bn,
nf_factor=nf_factor,
**kwargs).eval()
layers.append(unet_block)
x = unet_block(x)

ni = x.shape[1]
if imsize != sfs_szs[0][-2:]:
layers.append(PixelShuffle_ICNR(ni, **kwargs))
if last_cross:
layers.append(MergeLayer(dense=True))
ni += in_channels(encoder)
layers.append(
res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
layers += [
custom_conv_layer(
ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
]
if y_range is not None:
layers.append(SigmoidRange(*y_range))
super().__init__(*layers)

def __del__(self):
if hasattr(self, 'sfs'):
self.sfs.remove()


# ------------------------------------------------------
class UnetBlockWide(nn.Module):
    'A quasi-UNet block, using `PixelShuffle_ICNR` upsampling.'

def __init__(self,
up_in_c,
x_in_c,
n_out,
hook,
final_div=True,
blur=False,
leaky=None,
self_attention=False,
**kwargs):
super().__init__()
self.hook = hook
up_out = x_out = n_out // 2
self.shuf = CustomPixelShuffle_ICNR(
up_in_c, up_out, blur=blur, leaky=leaky, **kwargs)
self.bn = nn.BatchNorm2d(x_in_c)
ni = up_out + x_in_c
self.conv = custom_conv_layer(
ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs)
self.relu = relu(leaky=leaky)

def forward(self, up_in):
s = self.hook.stored
up_out = self.shuf(up_in)
ssh = s.shape[-2:]
if ssh != up_out.shape[-2:]:
up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
return self.conv(cat_x)


class DynamicUnetWide(SequentialEx):
'Create a U-Net from a given architecture.'

def __init__(self,
encoder,
n_classes,
blur=False,
blur_final=True,
self_attention=False,
y_range=None,
last_cross=True,
bottle=False,
norm_type=NormType.Batch,
nf_factor=1,
**kwargs):

nf = 512 * nf_factor
extra_bn = norm_type == NormType.Spectral
imsize = (256, 256)
sfs_szs = model_sizes(encoder, size=imsize)
sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
x = dummy_eval(encoder, imsize).detach()

ni = sfs_szs[-1][1]
middle_conv = nn.Sequential(
custom_conv_layer(
ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
custom_conv_layer(
ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
).eval()
x = middle_conv(x)
layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv]

for i, idx in enumerate(sfs_idxs):
not_final = i != len(sfs_idxs) - 1
up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
sa = self_attention and (i == len(sfs_idxs) - 3)

n_out = nf if not_final else nf // 2

unet_block = UnetBlockWide(
up_in_c,
x_in_c,
n_out,
self.sfs[i],
final_div=not_final,
blur=blur,
self_attention=sa,
norm_type=norm_type,
extra_bn=extra_bn,
**kwargs).eval()
layers.append(unet_block)
x = unet_block(x)

ni = x.shape[1]
if imsize != sfs_szs[0][-2:]:
layers.append(PixelShuffle_ICNR(ni, **kwargs))
if last_cross:
layers.append(MergeLayer(dense=True))
ni += in_channels(encoder)
layers.append(
res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
layers += [
custom_conv_layer(
ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
]
if y_range is not None:
layers.append(SigmoidRange(*y_range))
super().__init__(*layers)

def __del__(self):
if hasattr(self, 'sfs'):
self.sfs.remove()
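
As a point of reference, here is a minimal illustrative sketch (not part of the commit) of how these classes are driven. The ResNet-101 trunk cut at its first eight children and every hyperparameter below mirror image_colorization_pipeline.py later in this commit; the expected output shape in the final comment follows from the U-Net construction, not from a tested claim.

import torch
from torchvision import models

from modelscope.models.cv.image_colorization import unet
from modelscope.models.cv.image_colorization.utils import NormType

# Keep the convolutional trunk of ResNet-101, dropping avgpool and fc.
body = models.resnet101(pretrained=True)
body = torch.nn.Sequential(*list(body.children())[:8])

net = unet.DynamicUnetWide(
    body,
    n_classes=3,                  # predict a 3-channel (RGB) image
    blur=True,
    self_attention=True,
    y_range=(-3.0, 3.0),          # roughly matches the Normalize() statistics
    norm_type=NormType.Spectral,
    nf_factor=2)

with torch.no_grad():
    out = net.eval()(torch.rand(1, 3, 256, 256))  # expected: (1, 3, 256, 256)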

modelscope/models/cv/image_colorization/utils.py (+348, -0)

@@ -0,0 +1,348 @@
import functools
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

NormType = Enum('NormType',
'Batch BatchZero Weight Spectral Group Instance SpectralGN')


def is_listy(x):
return isinstance(x, (tuple, list))


class Hook():
'Create a hook on `m` with `hook_func`.'

def __init__(self, m, hook_func, is_forward=True, detach=True):
self.hook_func, self.detach, self.stored = hook_func, detach, None
f = m.register_forward_hook if is_forward else m.register_backward_hook
self.hook = f(self.hook_fn)
self.removed = False

def hook_fn(self, module, input, output):
'Applies `hook_func` to `module`, `input`, `output`.'
if self.detach:
input = (o.detach()
for o in input) if is_listy(input) else input.detach()
output = (
o.detach()
for o in output) if is_listy(output) else output.detach()
self.stored = self.hook_func(module, input, output)

def remove(self):
'Remove the hook from the model.'
if not self.removed:
self.hook.remove()
self.removed = True

def __enter__(self, *args):
return self

def __exit__(self, *args):
self.remove()


class Hooks():
'Create several hooks on the modules in `ms` with `hook_func`.'

def __init__(self, ms, hook_func, is_forward=True, detach=True):
self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]

def __getitem__(self, i):
return self.hooks[i]

def __len__(self):
return len(self.hooks)

def __iter__(self):
return iter(self.hooks)

@property
def stored(self):
return [o.stored for o in self]

def remove(self):
'Remove the hooks from the model.'
for h in self.hooks:
h.remove()

def __enter__(self, *args):
return self

def __exit__(self, *args):
self.remove()


def _hook_inner(m, i, o):
return o if isinstance(o, torch.Tensor) else o if is_listy(o) else list(o)


def hook_outputs(modules, detach=True, grad=False):
'Return `Hooks` that store activations of all `modules` in `self.stored`'
return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)


def one_param(m):
'Return the first parameter of `m`.'
return next(m.parameters())


def dummy_batch(m, size=(64, 64)):
'Create a dummy batch to go through `m` with `size`.'
ch_in = in_channels(m)
return one_param(m).new(1, ch_in,
*size).requires_grad_(False).uniform_(-1., 1.)


def dummy_eval(m, size=(64, 64)):
'Pass a `dummy_batch` in evaluation mode in `m` with `size`.'
return m.eval()(dummy_batch(m, size))


def model_sizes(m, size=(64, 64)):
'Pass a dummy input through the model `m` to get the various sizes of activations.'
with hook_outputs(m) as hooks:
dummy_eval(m, size)
return [o.stored.shape for o in hooks]


class PrePostInitMeta(type):
'A metaclass that calls optional `__pre_init__` and `__post_init__` methods'

def __new__(cls, name, bases, dct):
x = super().__new__(cls, name, bases, dct)
old_init = x.__init__

def _pass(self):
pass

@functools.wraps(old_init)
def _init(self, *args, **kwargs):
self.__pre_init__()
old_init(self, *args, **kwargs)
self.__post_init__()

x.__init__ = _init
if not hasattr(x, '__pre_init__'):
x.__pre_init__ = _pass
if not hasattr(x, '__post_init__'):
x.__post_init__ = _pass
return x


class Module(nn.Module, metaclass=PrePostInitMeta):
'Same as `nn.Module`, but no need for subclasses to call `super().__init__`'

def __pre_init__(self):
super().__init__()

def __init__(self):
pass


def children(m):
'Get children of `m`.'
return list(m.children())


def num_children(m):
'Get number of children modules in `m`.'
return len(children(m))


class ParameterModule(Module):
    'Register a lone parameter `p` in a module.'

    def __init__(self, p: nn.Parameter):
        self.val = p

    def forward(self, x):
        return x


def children_and_parameters(m: nn.Module):
    'Return the children of `m` and its direct parameters not registered in modules.'
    children = list(m.children())
    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],
                     [])
    for p in m.parameters():
        if id(p) not in children_p:
            children.append(ParameterModule(p))
    return children


def flatten_model(m):
if num_children(m):
mapped = map(flatten_model, children_and_parameters(m))
return sum(mapped, [])
else:
return [m]


def in_channels(m):
'Return the shape of the first weight layer in `m`.'
for layer in flatten_model(m):
if hasattr(layer, 'weight'):
return layer.weight.shape[1]
raise Exception('No weight layer')


def relu(inplace: bool = False, leaky: float = None):
'Return a relu activation, maybe `leaky` and `inplace`.'
return nn.LeakyReLU(
inplace=inplace,
negative_slope=leaky) if leaky is not None else nn.ReLU(
inplace=inplace)


def conv_layer(ni,
nf,
ks=3,
stride=1,
padding=None,
bias=None,
is_1d=False,
norm_type=NormType.Batch,
use_activ=True,
leaky=None,
transpose=False,
init=nn.init.kaiming_normal_,
self_attention=False):
'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.'
if padding is None:
padding = (ks - 1) // 2 if not transpose else 0
bn = norm_type in (NormType.Batch, NormType.BatchZero)
if bias is None:
bias = not bn
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
conv = conv_func(
ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)
if norm_type == NormType.Weight:
conv = weight_norm(conv)
elif norm_type == NormType.Spectral:
conv = spectral_norm(conv)
layers = [conv]
if use_activ:
layers.append(relu(True, leaky=leaky))
if bn:
layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
if self_attention:
layers.append(SelfAttention(nf))
return nn.Sequential(*layers)


def res_block(nf,
dense=False,
norm_type=NormType.Batch,
bottle=False,
**conv_kwargs):
'Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`.'
norm2 = norm_type
if not dense and (norm_type == NormType.Batch):
norm2 = NormType.BatchZero
nf_inner = nf // 2 if bottle else nf
return SequentialEx(
conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs),
conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs),
MergeLayer(dense))


def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False):
'Create and initialize a `nn.Conv1d` layer with spectral normalization.'
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
nn.init.kaiming_normal_(conv.weight)
if bias:
conv.bias.data.zero_()
return spectral_norm(conv)


class SelfAttention(Module):
'Self attention layer for nd.'

def __init__(self, n_channels):
self.query = conv1d(n_channels, n_channels // 8)
self.key = conv1d(n_channels, n_channels // 8)
self.value = conv1d(n_channels, n_channels)
self.gamma = nn.Parameter(torch.tensor([0.]))

def forward(self, x):
'Notation from https://arxiv.org/pdf/1805.08318.pdf'
size = x.size()
x = x.view(*size[:2], -1)
f, g, h = self.query(x), self.key(x), self.value(x)
beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
o = self.gamma * torch.bmm(h, beta) + x
return o.view(*size).contiguous()


def sigmoid_range(x, low, high):
'Sigmoid function with range `(low, high)`'
return torch.sigmoid(x) * (high - low) + low


class SigmoidRange(Module):
    'Sigmoid module with range `(low, high)`'

def __init__(self, low, high):
self.low, self.high = low, high

def forward(self, x):
return sigmoid_range(x, self.low, self.high)


class SequentialEx(Module):
'Like `nn.Sequential`, but with ModuleList semantics, and can access module input'

def __init__(self, *layers):
self.layers = nn.ModuleList(layers)

def forward(self, x):
res = x
for layer in self.layers:
res.orig = x
nres = layer(res)
res.orig = None
res = nres
return res

def __getitem__(self, i):
return self.layers[i]

def append(self, layer):
return self.layers.append(layer)

def extend(self, layer):
return self.layers.extend(layer)

def insert(self, i, layer):
return self.layers.insert(i, layer)


class MergeLayer(Module):
    'Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`.'

def __init__(self, dense: bool = False):
self.dense = dense

def forward(self, x):
return torch.cat([x, x.orig], dim=1) if self.dense else (x + x.orig)


class PixelShuffle_ICNR(Module):
'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.'

def __init__(self,
ni: int,
nf: int = None,
scale: int = 2,
blur: bool = False,
norm_type=NormType.Weight,
leaky: float = None):
nf = ni if nf is None else nf
self.conv = conv_layer(
ni, nf * (scale**2), ks=1, norm_type=norm_type, use_activ=False)
self.shuf = nn.PixelShuffle(scale)
# Blurring over (h*w) kernel
# "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
# - https://arxiv.org/abs/1806.02658
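    # NOTE: `self.blur` is assigned the AvgPool2d module below, which is
    # always truthy, so the blur branch in `forward` runs regardless of
    # the `blur` flag.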
self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
self.blur = nn.AvgPool2d(2, stride=1)
self.relu = relu(True, leaky=leaky)

def forward(self, x):
x = self.shuf(self.relu(self.conv(x)))
return self.blur(self.pad(x)) if self.blur else x
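
To make the hook machinery above concrete, here is a small, self-contained sketch (illustrative only): model_sizes drives a dummy batch through an encoder and records each child's output shape via forward hooks, which is exactly how DynamicUnetDeep/DynamicUnetWide locate their skip connections.

import torch.nn as nn

from modelscope.models.cv.image_colorization.utils import model_sizes

# A toy two-stage encoder; each conv halves the spatial resolution.
enc = nn.Sequential(
    nn.Conv2d(3, 8, 3, stride=2, padding=1),
    nn.Conv2d(8, 16, 3, stride=2, padding=1))

# Forward hooks record each child's activation shape for a dummy batch.
print(model_sizes(enc, size=(64, 64)))
# -> [torch.Size([1, 8, 32, 32]), torch.Size([1, 16, 16, 16])]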

modelscope/outputs.py (+1, -0)

@@ -70,6 +70,7 @@ TASK_OUTPUTS = {
     Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
     Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
     Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
+    Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
     Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
     Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],




modelscope/pipelines/builder.py (+2, -0)

@@ -68,6 +68,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.text_to_image_synthesis:
     (Pipelines.text_to_image_synthesis,
      'damo/cv_imagen_text-to-image-synthesis_tiny'),
+    Tasks.image_colorization: (Pipelines.image_colorization,
+                               'damo/cv_unet_image-colorization'),
     Tasks.style_transfer: (Pipelines.style_transfer,
                            'damo/cv_aams_style-transfer_damo'),
     Tasks.face_image_generation: (Pipelines.face_image_generation,


modelscope/pipelines/cv/__init__.py (+1, -0)

@@ -6,6 +6,7 @@ try:
     from .action_recognition_pipeline import ActionRecognitionPipeline
     from .animal_recog_pipeline import AnimalRecogPipeline
     from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
+    from .image_colorization_pipeline import ImageColorizationPipeline
     from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
     from .face_image_generation_pipeline import FaceImageGenerationPipeline
 except ModuleNotFoundError as e:


modelscope/pipelines/cv/image_colorization_pipeline.py (+132, -0)

@@ -0,0 +1,132 @@
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch
from torchvision import models, transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_colorization import unet
from modelscope.models.cv.image_colorization.utils import NormType
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_colorization, module_name=Pipelines.image_colorization)
class ImageColorizationPipeline(Pipeline):

    def __init__(self, model: str):
        """
        use `model` to create an image colorization pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.cut = 8
        self.size = 1024 if self.device == 'cpu' else 512
        self.orig_img = None
        self.model_type = 'stable'
self.norm = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.denorm = transforms.Normalize(
mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
std=[1 / 0.229, 1 / 0.224, 1 / 0.225])

if self.model_type == 'stable':
body = models.resnet101(pretrained=True)
body = torch.nn.Sequential(*list(body.children())[:self.cut])
self.model = unet.DynamicUnetWide(
body,
n_classes=3,
blur=True,
blur_final=True,
self_attention=True,
y_range=(-3.0, 3.0),
norm_type=NormType.Spectral,
last_cross=True,
bottle=False,
nf_factor=2,
).to(self.device)
        else:
            body = models.resnet34(pretrained=True)
            body = torch.nn.Sequential(*list(body.children())[:self.cut])
            self.model = unet.DynamicUnetDeep(
                body,
                n_classes=3,
                blur=True,
                blur_final=True,
                self_attention=True,
                y_range=(-3.0, 3.0),
                norm_type=NormType.Spectral,
                last_cross=True,
                bottle=False,
                nf_factor=1.5,
            ).to(self.device)

        model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device)['model'],
            strict=True)

logger.info('load model done')

def preprocess(self, input: Input) -> Dict[str, Any]:
if isinstance(input, str):
img = load_image(input).convert('LA').convert('RGB')
elif isinstance(input, PIL.Image.Image):
img = input.convert('LA').convert('RGB')
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            # assume BGR input (cv2 convention); make the RGB view
            # contiguous so PIL can wrap it
            img = np.ascontiguousarray(input[:, :, ::-1])
            img = PIL.Image.fromarray(img).convert('LA').convert('RGB')
        else:
            raise TypeError(f'input should be either str, PIL.Image '
                            f'or np.ndarray, but got {type(input)}')

        self.width, self.height = img.size
        self.orig_img = None  # reset per call; kept only when we downscale
        if self.width * self.height > self.size * self.size:
            self.orig_img = img.copy()
            img = img.resize((self.size, self.size),
                             resample=PIL.Image.BILINEAR)

img = self.norm(img).unsqueeze(0).to(self.device)
result = {'img': img}

return result

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
self.model.eval()
with torch.no_grad():
out = self.model(input['img'])[0]

out = self.denorm(out)
out = out.float().clamp(min=0, max=1)
out_img = (out.permute(1, 2, 0).flip(2).cpu().numpy() * 255).astype(
np.uint8)

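        # The network ran on a downscaled copy; recover full resolution by
        # keeping the original luminance (Y) and taking only the predicted
        # chrominance (UV) from the low-resolution output.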
if self.orig_img is not None:
color_np = cv2.resize(out_img, self.orig_img.size)
orig_np = np.asarray(self.orig_img)
color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
hires = np.copy(orig_yuv)
hires[:, :, 1:3] = color_yuv[:, :, 1:3]
out_img = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)

return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs
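
For completeness, a minimal end-to-end usage sketch (not part of the commit) that mirrors tests/pipelines/test_image_colorization.py below, using the default model id registered in builder.py above:

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

colorizer = pipeline(
    Tasks.image_colorization, model='damo/cv_unet_image-colorization')
result = colorizer('data/test/images/marilyn_monroe_4.jpg')
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])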

modelscope/utils/constant.py (+1, -0)

@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    image_colorization = 'image-colorization'
     face_image_generation = 'face-image-generation'
     image_super_resolution = 'image-super-resolution'
     style_transfer = 'style-transfer'


tests/pipelines/test_image_colorization.py (+42, -0)

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageColorizationTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_unet_image-colorization'
self.test_image = 'data/test/images/marilyn_monroe_4.jpg'

def pipeline_inference(self, pipeline: Pipeline, test_image: str):
result = pipeline(test_image)
if result is not None:
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub(self):
image_colorization = pipeline(
Tasks.image_colorization, model=self.model_id)

self.pipeline_inference(image_colorization, self.test_image)

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_modelhub_default_model(self):
image_colorization = pipeline(Tasks.image_colorization)
self.pipeline_inference(image_colorization, self.test_image)


if __name__ == '__main__':
unittest.main()
