diff --git a/data/test/images/marilyn_monroe_4.jpg b/data/test/images/marilyn_monroe_4.jpg
new file mode 100644
index 00000000..cdcf22b0
--- /dev/null
+++ b/data/test/images/marilyn_monroe_4.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b425fb89442e4c6c32c71c17c1c1afef8a2c5bc9ec9529b5a0fc21c53e1a02b
+size 39248
diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 28f88cd5..a3e93296 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -50,6 +50,7 @@ class Pipelines(object):
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognation = 'resnet101-animal_recog'
     cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
+    image_colorization = 'unet-image-colorization'
     image_super_resolution = 'rrdb-image-super-resolution'
     face_image_generation = 'gan-face-image-generation'
     style_transfer = 'AAMS-style-transfer'
diff --git a/modelscope/models/cv/image_colorization/unet.py b/modelscope/models/cv/image_colorization/unet.py
new file mode 100644
index 00000000..8123651e
--- /dev/null
+++ b/modelscope/models/cv/image_colorization/unet.py
@@ -0,0 +1,300 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+from .utils import (MergeLayer, NormType, PixelShuffle_ICNR, SelfAttention,
+                    SequentialEx, SigmoidRange, dummy_eval, hook_outputs,
+                    in_channels, model_sizes, relu, res_block)
+
+__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']
+
+
+def custom_conv_layer(
+    ni,
+    nf,
+    ks=3,
+    stride=1,
+    padding=None,
+    bias=None,
+    is_1d=False,
+    norm_type=NormType.Batch,
+    use_activ=True,
+    leaky=None,
+    transpose=False,
+    init=nn.init.kaiming_normal_,
+    self_attention=False,
+    extra_bn=False,
+):
+    'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.'
+    # NB: `init` is accepted for API compatibility but is not applied in this port.
+    if padding is None:
+        padding = (ks - 1) // 2 if not transpose else 0
+    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn
+    if bias is None:
+        bias = not bn
+    # Pick the conv type in one step (the previous two-step assignment silently
+    # replaced ConvTranspose2d with Conv2d whenever `is_1d` was False); this now
+    # matches `conv_layer` in utils.py.
+    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
+    conv = conv_func(
+        ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)
+    if norm_type == NormType.Weight:
+        conv = weight_norm(conv)
+    elif norm_type == NormType.Spectral:
+        conv = spectral_norm(conv)
+
+    layers = [conv]
+    if use_activ:
+        layers.append(relu(True, leaky=leaky))
+    if bn:
+        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
+    if self_attention:
+        layers.append(SelfAttention(nf))
+    return nn.Sequential(*layers)
+
+
+def _get_sfs_idxs(sizes):
+    'Get the indexes of the layers where the size of the activation changes.'
+    feature_szs = [size[-1] for size in sizes]
+    sfs_idxs = list(
+        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
+    if feature_szs[0] != feature_szs[1]:
+        sfs_idxs = [0] + sfs_idxs
+    return sfs_idxs
+
+
+class CustomPixelShuffle_ICNR(nn.Module):
+    'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.'
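+    # Review note: `self.blur` is assigned an nn.AvgPool2d module below, and an
+    # nn.Module instance is always truthy, so the `if self.blur` check in
+    # `forward` always takes the blur branch; the `blur` constructor flag is
+    # effectively ignored. The pretrained colorization weights presumably
+    # expect this always-blur behaviour (the pipeline passes blur=True
+    # throughout), so it is documented here rather than changed.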
+ + def __init__(self, ni, nf=None, scale=2, blur=False, leaky=None, **kwargs): + super().__init__() + nf = ni if nf is None else nf + self.conv = custom_conv_layer( + ni, nf * (scale**2), ks=1, use_activ=False, **kwargs) + self.shuf = nn.PixelShuffle(scale) + # Blurring over (h*w) kernel + # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts" + # - https://arxiv.org/abs/1806.02658 + self.pad = nn.ReplicationPad2d((1, 0, 1, 0)) + self.blur = nn.AvgPool2d(2, stride=1) + self.relu = relu(True, leaky=leaky) + + def forward(self, x): + x = self.shuf(self.relu(self.conv(x))) + return self.blur(self.pad(x)) if self.blur else x + + +class UnetBlockDeep(nn.Module): + 'A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.' + + def __init__(self, + up_in_c, + x_in_c, + hook, + final_div=True, + blur=False, + leaky=None, + self_attention=False, + nf_factor=1.0, + **kwargs): + super().__init__() + self.hook = hook + self.shuf = CustomPixelShuffle_ICNR( + up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs) + self.bn = nn.BatchNorm2d(x_in_c) + ni = up_in_c // 2 + x_in_c + nf = int((ni if final_div else ni // 2) * nf_factor) + self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs) + self.conv2 = custom_conv_layer( + nf, nf, leaky=leaky, self_attention=self_attention, **kwargs) + self.relu = relu(leaky=leaky) + + def forward(self, up_in): + s = self.hook.stored + up_out = self.shuf(up_in) + ssh = s.shape[-2:] + if ssh != up_out.shape[-2:]: + up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest') + cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1)) + return self.conv2(self.conv1(cat_x)) + + +class DynamicUnetDeep(SequentialEx): + 'Create a U-Net from a given architecture.' + + def __init__(self, + encoder, + n_classes, + blur=False, + blur_final=True, + self_attention=False, + y_range=None, + last_cross=True, + bottle=False, + norm_type=NormType.Batch, + nf_factor=1.0, + **kwargs): + extra_bn = norm_type == NormType.Spectral + imsize = (256, 256) + sfs_szs = model_sizes(encoder, size=imsize) + sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs))) + self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False) + x = dummy_eval(encoder, imsize).detach() + + ni = sfs_szs[-1][1] + middle_conv = nn.Sequential( + custom_conv_layer( + ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + custom_conv_layer( + ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + ).eval() + x = middle_conv(x) + layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv] + + for i, idx in enumerate(sfs_idxs): + not_final = i != len(sfs_idxs) - 1 + up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1]) + sa = self_attention and (i == len(sfs_idxs) - 3) + unet_block = UnetBlockDeep( + up_in_c, + x_in_c, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + nf_factor=nf_factor, + **kwargs).eval() + layers.append(unet_block) + x = unet_block(x) + + ni = x.shape[1] + if imsize != sfs_szs[0][-2:]: + layers.append(PixelShuffle_ICNR(ni, **kwargs)) + if last_cross: + layers.append(MergeLayer(dense=True)) + ni += in_channels(encoder) + layers.append( + res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs)) + layers += [ + custom_conv_layer( + ni, n_classes, ks=1, use_activ=False, norm_type=norm_type) + ] + if y_range is not None: + layers.append(SigmoidRange(*y_range)) + super().__init__(*layers) + + def __del__(self): + if hasattr(self, 'sfs'): + self.sfs.remove() + + +# 
------------------------------------------------------ +class UnetBlockWide(nn.Module): + 'A quasi-UNet block, using `PixelShuffle_ICNR upsampling`.' + + def __init__(self, + up_in_c, + x_in_c, + n_out, + hook, + final_div=True, + blur=False, + leaky=None, + self_attention=False, + **kwargs): + super().__init__() + self.hook = hook + up_out = x_out = n_out // 2 + self.shuf = CustomPixelShuffle_ICNR( + up_in_c, up_out, blur=blur, leaky=leaky, **kwargs) + self.bn = nn.BatchNorm2d(x_in_c) + ni = up_out + x_in_c + self.conv = custom_conv_layer( + ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs) + self.relu = relu(leaky=leaky) + + def forward(self, up_in): + s = self.hook.stored + up_out = self.shuf(up_in) + ssh = s.shape[-2:] + if ssh != up_out.shape[-2:]: + up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest') + cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1)) + return self.conv(cat_x) + + +class DynamicUnetWide(SequentialEx): + 'Create a U-Net from a given architecture.' + + def __init__(self, + encoder, + n_classes, + blur=False, + blur_final=True, + self_attention=False, + y_range=None, + last_cross=True, + bottle=False, + norm_type=NormType.Batch, + nf_factor=1, + **kwargs): + + nf = 512 * nf_factor + extra_bn = norm_type == NormType.Spectral + imsize = (256, 256) + sfs_szs = model_sizes(encoder, size=imsize) + sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs))) + self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False) + x = dummy_eval(encoder, imsize).detach() + + ni = sfs_szs[-1][1] + middle_conv = nn.Sequential( + custom_conv_layer( + ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + custom_conv_layer( + ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs), + ).eval() + x = middle_conv(x) + layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv] + + for i, idx in enumerate(sfs_idxs): + not_final = i != len(sfs_idxs) - 1 + up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1]) + sa = self_attention and (i == len(sfs_idxs) - 3) + + n_out = nf if not_final else nf // 2 + + unet_block = UnetBlockWide( + up_in_c, + x_in_c, + n_out, + self.sfs[i], + final_div=not_final, + blur=blur, + self_attention=sa, + norm_type=norm_type, + extra_bn=extra_bn, + **kwargs).eval() + layers.append(unet_block) + x = unet_block(x) + + ni = x.shape[1] + if imsize != sfs_szs[0][-2:]: + layers.append(PixelShuffle_ICNR(ni, **kwargs)) + if last_cross: + layers.append(MergeLayer(dense=True)) + ni += in_channels(encoder) + layers.append( + res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs)) + layers += [ + custom_conv_layer( + ni, n_classes, ks=1, use_activ=False, norm_type=norm_type) + ] + if y_range is not None: + layers.append(SigmoidRange(*y_range)) + super().__init__(*layers) + + def __del__(self): + if hasattr(self, 'sfs'): + self.sfs.remove() diff --git a/modelscope/models/cv/image_colorization/utils.py b/modelscope/models/cv/image_colorization/utils.py new file mode 100644 index 00000000..03473f90 --- /dev/null +++ b/modelscope/models/cv/image_colorization/utils.py @@ -0,0 +1,348 @@ +import functools +from enum import Enum + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import spectral_norm, weight_norm + +NormType = Enum('NormType', + 'Batch BatchZero Weight Spectral Group Instance SpectralGN') + + +def is_listy(x): + return isinstance(x, (tuple, list)) + + +class Hook(): + 'Create a hook on `m` with `hook_func`.' 
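+    # Review note: these hooks are what feed the U-Net skip connections in
+    # unet.py: `hook_outputs` (below) attaches a `Hook` to each encoder stage
+    # whose activation size changes, and the `UnetBlock*` modules read the
+    # captured activation from `hook.stored` during `forward`.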
+ + def __init__(self, m, hook_func, is_forward=True, detach=True): + self.hook_func, self.detach, self.stored = hook_func, detach, None + f = m.register_forward_hook if is_forward else m.register_backward_hook + self.hook = f(self.hook_fn) + self.removed = False + + def hook_fn(self, module, input, output): + 'Applies `hook_func` to `module`, `input`, `output`.' + if self.detach: + input = (o.detach() + for o in input) if is_listy(input) else input.detach() + output = ( + o.detach() + for o in output) if is_listy(output) else output.detach() + self.stored = self.hook_func(module, input, output) + + def remove(self): + 'Remove the hook from the model.' + if not self.removed: + self.hook.remove() + self.removed = True + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() + + +class Hooks(): + 'Create several hooks on the modules in `ms` with `hook_func`.' + + def __init__(self, ms, hook_func, is_forward=True, detach=True): + self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms] + + def __getitem__(self, i): + return self.hooks[i] + + def __len__(self): + return len(self.hooks) + + def __iter__(self): + return iter(self.hooks) + + @property + def stored(self): + return [o.stored for o in self] + + def remove(self): + 'Remove the hooks from the model.' + for h in self.hooks: + h.remove() + + def __enter__(self, *args): + return self + + def __exit__(self, *args): + self.remove() + + +def _hook_inner(m, i, o): + return o if isinstance(o, torch.Tensor) else o if is_listy(o) else list(o) + + +def hook_outputs(modules, detach=True, grad=False): + 'Return `Hooks` that store activations of all `modules` in `self.stored`' + return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad) + + +def one_param(m): + 'Return the first parameter of `m`.' + return next(m.parameters()) + + +def dummy_batch(m, size=(64, 64)): + 'Create a dummy batch to go through `m` with `size`.' + ch_in = in_channels(m) + return one_param(m).new(1, ch_in, + *size).requires_grad_(False).uniform_(-1., 1.) + + +def dummy_eval(m, size=(64, 64)): + 'Pass a `dummy_batch` in evaluation mode in `m` with `size`.' + return m.eval()(dummy_batch(m, size)) + + +def model_sizes(m, size=(64, 64)): + 'Pass a dummy input through the model `m` to get the various sizes of activations.' + with hook_outputs(m) as hooks: + dummy_eval(m, size) + return [o.stored.shape for o in hooks] + + +class PrePostInitMeta(type): + 'A metaclass that calls optional `__pre_init__` and `__post_init__` methods' + + def __new__(cls, name, bases, dct): + x = super().__new__(cls, name, bases, dct) + old_init = x.__init__ + + def _pass(self): + pass + + @functools.wraps(old_init) + def _init(self, *args, **kwargs): + self.__pre_init__() + old_init(self, *args, **kwargs) + self.__post_init__() + + x.__init__ = _init + if not hasattr(x, '__pre_init__'): + x.__pre_init__ = _pass + if not hasattr(x, '__post_init__'): + x.__post_init__ = _pass + return x + + +class Module(nn.Module, metaclass=PrePostInitMeta): + 'Same as `nn.Module`, but no need for subclasses to call `super().__init__`' + + def __pre_init__(self): + super().__init__() + + def __init__(self): + pass + + +def children(m): + 'Get children of `m`.' + return list(m.children()) + + +def num_children(m): + 'Get number of children modules in `m`.' + return len(children(m)) + + +def children_and_parameters(m: nn.Module): + 'Return the children of `m` and its direct parameters not registered in modules.' 
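+    # Review note: `ParameterModule` is used below but never defined in this
+    # file. A minimal sketch of the missing class, modelled on the fastai
+    # helper this module is ported from, defined locally so the name resolves:
+    class ParameterModule(nn.Module):
+        'Register a lone parameter `p` in a module.'
+
+        def __init__(self, p):
+            super().__init__()
+            self.val = p
+
+        def forward(self, x):
+            return x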
+ children = list(m.children()) + children_p = sum([[id(p) for p in c.parameters()] for c in m.children()], + []) + for p in m.parameters(): + if id(p) not in children_p: + children.append(ParameterModule(p)) + return children + + +def flatten_model(m): + if num_children(m): + mapped = map(flatten_model, children_and_parameters(m)) + return sum(mapped, []) + else: + return [m] + + +def in_channels(m): + 'Return the shape of the first weight layer in `m`.' + for layer in flatten_model(m): + if hasattr(layer, 'weight'): + return layer.weight.shape[1] + raise Exception('No weight layer') + + +def relu(inplace: bool = False, leaky: float = None): + 'Return a relu activation, maybe `leaky` and `inplace`.' + return nn.LeakyReLU( + inplace=inplace, + negative_slope=leaky) if leaky is not None else nn.ReLU( + inplace=inplace) + + +def conv_layer(ni, + nf, + ks=3, + stride=1, + padding=None, + bias=None, + is_1d=False, + norm_type=NormType.Batch, + use_activ=True, + leaky=None, + transpose=False, + init=nn.init.kaiming_normal_, + self_attention=False): + 'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.' + if padding is None: + padding = (ks - 1) // 2 if not transpose else 0 + bn = norm_type in (NormType.Batch, NormType.BatchZero) + if bias is None: + bias = not bn + conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d + conv = conv_func( + ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding) + if norm_type == NormType.Weight: + conv = weight_norm(conv) + elif norm_type == NormType.Spectral: + conv = spectral_norm(conv) + layers = [conv] + if use_activ: + layers.append(relu(True, leaky=leaky)) + if bn: + layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf)) + if self_attention: + layers.append(SelfAttention(nf)) + return nn.Sequential(*layers) + + +def res_block(nf, + dense=False, + norm_type=NormType.Batch, + bottle=False, + **conv_kwargs): + 'Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`.' + norm2 = norm_type + if not dense and (norm_type == NormType.Batch): + norm2 = NormType.BatchZero + nf_inner = nf // 2 if bottle else nf + return SequentialEx( + conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs), + conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs), + MergeLayer(dense)) + + +def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False): + 'Create and initialize a `nn.Conv1d` layer with spectral normalization.' + conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) + nn.init.kaiming_normal_(conv.weight) + if bias: + conv.bias.data.zero_() + return spectral_norm(conv) + + +class SelfAttention(Module): + 'Self attention layer for nd.' 
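+    # With 1x1 spectral-normalised convolutions f (query), g (key) and
+    # h (value) over flattened spatial positions, this computes
+    # beta = softmax(f(x)^T g(x)) and returns gamma * h(x) @ beta + x;
+    # `gamma` starts at 0, so the block is an identity mapping at
+    # initialisation (Self-Attention GAN, https://arxiv.org/abs/1805.08318).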
+
+    def __init__(self, n_channels):
+        self.query = conv1d(n_channels, n_channels // 8)
+        self.key = conv1d(n_channels, n_channels // 8)
+        self.value = conv1d(n_channels, n_channels)
+        self.gamma = nn.Parameter(torch.tensor([0.]))
+
+    def forward(self, x):
+        'Notation from https://arxiv.org/pdf/1805.08318.pdf'
+        size = x.size()
+        x = x.view(*size[:2], -1)
+        f, g, h = self.query(x), self.key(x), self.value(x)
+        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
+        o = self.gamma * torch.bmm(h, beta) + x
+        return o.view(*size).contiguous()
+
+
+def sigmoid_range(x, low, high):
+    'Sigmoid function with range `(low, high)`'
+    return torch.sigmoid(x) * (high - low) + low
+
+
+class SigmoidRange(Module):
+    'Sigmoid module with range `(low, high)`'
+
+    def __init__(self, low, high):
+        self.low, self.high = low, high
+
+    def forward(self, x):
+        return sigmoid_range(x, self.low, self.high)
+
+
+class SequentialEx(Module):
+    'Like `nn.Sequential`, but with ModuleList semantics, and can access module input'
+
+    def __init__(self, *layers):
+        self.layers = nn.ModuleList(layers)
+
+    def forward(self, x):
+        res = x
+        for layer in self.layers:
+            res.orig = x
+            nres = layer(res)
+            res.orig = None
+            res = nres
+        return res
+
+    def __getitem__(self, i):
+        return self.layers[i]
+
+    def append(self, layer):
+        return self.layers.append(layer)
+
+    def extend(self, layer):
+        return self.layers.extend(layer)
+
+    def insert(self, i, layer):
+        return self.layers.insert(i, layer)
+
+
+class MergeLayer(Module):
+    'Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`.'
+
+    def __init__(self, dense: bool = False):
+        self.dense = dense
+
+    def forward(self, x):
+        return torch.cat([x, x.orig], dim=1) if self.dense else (x + x.orig)
+
+
+class PixelShuffle_ICNR(Module):
+    'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.'
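+    # Review note: despite the ICNR name, this port does not apply ICNR
+    # initialisation to `self.conv` (the fastai original initialises that conv
+    # with `icnr` to reduce checkerboard artifacts). This is harmless for
+    # inference, where the pretrained checkpoint overwrites the initial values.
+    # The truthy-`self.blur` caveat noted on CustomPixelShuffle_ICNR in unet.py
+    # applies to this class as well.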
+ + def __init__(self, + ni: int, + nf: int = None, + scale: int = 2, + blur: bool = False, + norm_type=NormType.Weight, + leaky: float = None): + nf = ni if nf is None else nf + self.conv = conv_layer( + ni, nf * (scale**2), ks=1, norm_type=norm_type, use_activ=False) + self.shuf = nn.PixelShuffle(scale) + # Blurring over (h*w) kernel + # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts" + # - https://arxiv.org/abs/1806.02658 + self.pad = nn.ReplicationPad2d((1, 0, 1, 0)) + self.blur = nn.AvgPool2d(2, stride=1) + self.relu = relu(True, leaky=leaky) + + def forward(self, x): + x = self.shuf(self.relu(self.conv(x))) + return self.blur(self.pad(x)) if self.blur else x diff --git a/modelscope/outputs.py b/modelscope/outputs.py index abf7e323..eda56006 100644 --- a/modelscope/outputs.py +++ b/modelscope/outputs.py @@ -70,6 +70,7 @@ TASK_OUTPUTS = { Tasks.image_editing: [OutputKeys.OUTPUT_IMG], Tasks.image_matting: [OutputKeys.OUTPUT_IMG], Tasks.image_generation: [OutputKeys.OUTPUT_IMG], + Tasks.image_colorization: [OutputKeys.OUTPUT_IMG], Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG], Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG], diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 4cb698f1..c008127f 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -68,6 +68,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.text_to_image_synthesis: (Pipelines.text_to_image_synthesis, 'damo/cv_imagen_text-to-image-synthesis_tiny'), + Tasks.image_colorization: (Pipelines.image_colorization, + 'damo/cv_unet_image-colorization'), Tasks.style_transfer: (Pipelines.style_transfer, 'damo/cv_aams_style-transfer_damo'), Tasks.face_image_generation: (Pipelines.face_image_generation, diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index b6b6dfa7..75a85da3 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -6,6 +6,7 @@ try: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recog_pipeline import AnimalRecogPipeline from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline + from .image_colorization_pipeline import ImageColorizationPipeline from .image_super_resolution_pipeline import ImageSuperResolutionPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline except ModuleNotFoundError as e: diff --git a/modelscope/pipelines/cv/image_colorization_pipeline.py b/modelscope/pipelines/cv/image_colorization_pipeline.py new file mode 100644 index 00000000..5080b300 --- /dev/null +++ b/modelscope/pipelines/cv/image_colorization_pipeline.py @@ -0,0 +1,132 @@ +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch +from torchvision import models, transforms + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_colorization import unet +from modelscope.models.cv.image_colorization.utils import NormType +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input +from modelscope.preprocessors import load_image +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_colorization, module_name=Pipelines.image_colorization) +class ImageColorizationPipeline(Pipeline): + + def __init__(self, model: str): + """ + 
use `model` to create an image colorization pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model)
+        # Fall back to CPU when no GPU is available (the device was previously
+        # hard-coded to 'cuda').
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.cut = 8
+        self.size = 1024 if self.device == 'cpu' else 512
+        self.orig_img = None
+        self.model_type = 'stable'
+        self.norm = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ])
+        self.denorm = transforms.Normalize(
+            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
+            std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
+
+        if self.model_type == 'stable':
+            body = models.resnet101(pretrained=True)
+            body = torch.nn.Sequential(*list(body.children())[:self.cut])
+            self.model = unet.DynamicUnetWide(
+                body,
+                n_classes=3,
+                blur=True,
+                blur_final=True,
+                self_attention=True,
+                y_range=(-3.0, 3.0),
+                norm_type=NormType.Spectral,
+                last_cross=True,
+                bottle=False,
+                nf_factor=2,
+            ).to(self.device)
+        else:
+            body = models.resnet34(pretrained=True)
+            body = torch.nn.Sequential(*list(body.children())[:self.cut])
+            self.model = unet.DynamicUnetDeep(
+                body,
+                n_classes=3,
+                blur=True,
+                blur_final=True,
+                self_attention=True,
+                y_range=(-3.0, 3.0),
+                norm_type=NormType.Spectral,
+                last_cross=True,
+                bottle=False,
+                nf_factor=1.5,
+            ).to(self.device)
+
+        model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
+        self.model.load_state_dict(
+            torch.load(model_path, map_location=self.device)['model'],
+            strict=True)
+
+        logger.info('load model done')
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        # Reset state from any previous call so a small image does not reuse
+        # the high-resolution original of an earlier one.
+        self.orig_img = None
+        if isinstance(input, str):
+            img = load_image(input).convert('LA').convert('RGB')
+        elif isinstance(input, PIL.Image.Image):
+            img = input.convert('LA').convert('RGB')
+        elif isinstance(input, np.ndarray):
+            if len(input.shape) == 2:
+                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
+            img = input[:, :, ::-1]  # BGR -> RGB
+            img = PIL.Image.fromarray(img).convert('LA').convert('RGB')
+        else:
+            raise TypeError(f'input should be either str, PIL.Image or'
+                            f' np.ndarray, but got {type(input)}')
+
+        self.width, self.height = img.size
+        if self.width * self.height > self.size * self.size:
+            self.orig_img = img.copy()
+            img = img.resize((self.size, self.size),
+                             resample=PIL.Image.BILINEAR)
+
+        img = self.norm(img).unsqueeze(0).to(self.device)
+        result = {'img': img}
+
+        return result
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        self.model.eval()
+        with torch.no_grad():
+            out = self.model(input['img'])[0]
+
+        out = self.denorm(out)
+        out = out.float().clamp(min=0, max=1)
+        # CHW/RGB tensor -> HWC/BGR uint8 image for OpenCV.
+        out_img = (out.permute(1, 2, 0).flip(2).cpu().numpy() * 255).astype(
+            np.uint8)
+
+        if self.orig_img is not None:
+            # The network ran on a downscaled copy; keep the original
+            # full-resolution luminance (Y) and splice in only the predicted
+            # chrominance (U/V), upsampled back to the original size.
+            color_np = cv2.resize(out_img, self.orig_img.size)
+            orig_np = np.asarray(self.orig_img)
+            color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
+            orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
+            hires = np.copy(orig_yuv)
+            hires[:, :, 1:3] = color_yuv[:, :, 1:3]
+            out_img = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)
+
+        return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 352f8ce8..adfa8b98 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -27,6 +27,7 @@ class CVTasks(object):
     ocr_detection = 'ocr-detection'
     action_recognition = 'action-recognition'
     video_embedding = 'video-embedding'
+    image_colorization = 'image-colorization'
     face_image_generation = 'face-image-generation'
     image_super_resolution = 'image-super-resolution'
     style_transfer = 'style-transfer'
diff --git a/tests/pipelines/test_image_colorization.py b/tests/pipelines/test_image_colorization.py
new file mode 100644
index 00000000..14090363
--- /dev/null
+++ b/tests/pipelines/test_image_colorization.py
@@ -0,0 +1,42 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import os.path as osp
+import unittest
+
+import cv2
+
+from modelscope.msdatasets import MsDataset
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class ImageColorizationTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_unet_image-colorization'
+        self.test_image = 'data/test/images/marilyn_monroe_4.jpg'
+
+    def pipeline_inference(self, pipeline: Pipeline, test_image: str):
+        result = pipeline(test_image)
+        if result is not None:
+            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
+            print(f'Output written to {osp.abspath("result.png")}')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_run_modelhub(self):
+        image_colorization = pipeline(
+            Tasks.image_colorization, model=self.model_id)
+
+        self.pipeline_inference(image_colorization, self.test_image)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        image_colorization = pipeline(Tasks.image_colorization)
+        self.pipeline_inference(image_colorization, self.test_image)
+
+
+if __name__ == '__main__':
+    unittest.main()
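Usage note: a minimal sketch of how the new pipeline is invoked once this change
lands, mirroring tests/pipelines/test_image_colorization.py above (the model id
and test image path are the ones registered in this diff):

    import cv2

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the colorizer and write the colorized result to disk.
    colorizer = pipeline(
        Tasks.image_colorization, model='damo/cv_unet_image-colorization')
    result = colorizer('data/test/images/marilyn_monroe_4.jpg')
    cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])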