Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9460900master
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b425fb89442e4c6c32c71c17c1c1afef8a2c5bc9ec9529b5a0fc21c53e1a02b
size 39248
@@ -50,6 +50,7 @@ class Pipelines(object):
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognation = 'resnet101-animal_recog'
    cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
    image_colorization = 'unet-image-colorization'
    image_super_resolution = 'rrdb-image-super-resolution'
    face_image_generation = 'gan-face-image-generation'
    style_transfer = 'AAMS-style-transfer'

@@ -0,0 +1,300 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

from .utils import (MergeLayer, NormType, PixelShuffle_ICNR, SelfAttention,
                    SequentialEx, SigmoidRange, dummy_eval, hook_outputs,
                    in_channels, model_sizes, relu, res_block)

__all__ = ['DynamicUnetDeep', 'DynamicUnetWide']


def custom_conv_layer(
    ni,
    nf,
    ks=3,
    stride=1,
    padding=None,
    bias=None,
    is_1d=False,
    norm_type=NormType.Batch,
    use_activ=True,
    leaky=None,
    transpose=False,
    init=nn.init.kaiming_normal_,
    self_attention=False,
    extra_bn=False,
):
    'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.'
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn
    if bias is None:
        bias = not bn
    # `transpose` must take precedence over `is_1d` when picking the conv type.
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = conv_func(
        ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)
    if norm_type == NormType.Weight:
        conv = weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)
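

# Illustrative example: custom_conv_layer(64, 128, norm_type=NormType.Spectral,
# extra_bn=True) yields Sequential(spectral_norm(Conv2d(64, 128, 3, padding=1)),
# ReLU(inplace=True), BatchNorm2d(128)) -- the activation precedes the norm
# layer, following the fastai convention.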


def _get_sfs_idxs(sizes):
    'Get the indexes of the layers where the size of the activation changes.'
    feature_szs = [size[-1] for size in sizes]
    sfs_idxs = list(
        np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
    if feature_szs[0] != feature_szs[1]:
        sfs_idxs = [0] + sfs_idxs
    return sfs_idxs
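

# Illustrative example: for activation widths [64, 64, 32, 16, 8], comparing
# feature_szs[:-1] with feature_szs[1:] flags indices [1, 2, 3] (the layers
# after which the spatial size changes); if the first two entries differ, 0 is
# prepended as well (fastai v1 behavior, kept as-is).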


class CustomPixelShuffle_ICNR(nn.Module):
    'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.'

    def __init__(self, ni, nf=None, scale=2, blur=False, leaky=None, **kwargs):
        super().__init__()
        nf = ni if nf is None else nf
        # Note: despite the class name, no ICNR weight init is applied here;
        # the weights are expected to come from a pretrained checkpoint.
        self.conv = custom_conv_layer(
            ni, nf * (scale**2), ks=1, use_activ=False, **kwargs)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        # `self.blur` is an nn.AvgPool2d module and therefore always truthy, so
        # the anti-checkerboard blur runs unconditionally; the `blur` argument
        # of __init__ is accepted but never stored.
        return self.blur(self.pad(x)) if self.blur else x
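

# Shape sketch (illustrative): for input (B, ni, H, W), the 1x1 conv maps to
# (B, nf * scale**2, H, W), nn.PixelShuffle(scale) rearranges that to
# (B, nf, H*scale, W*scale), and ReplicationPad2d((1, 0, 1, 0)) followed by
# AvgPool2d(2, stride=1) preserves the upsampled size.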


class UnetBlockDeep(nn.Module):
    'A quasi-UNet block, using `PixelShuffle_ICNR` upsampling.'

    def __init__(self,
                 up_in_c,
                 x_in_c,
                 hook,
                 final_div=True,
                 blur=False,
                 leaky=None,
                 self_attention=False,
                 nf_factor=1.0,
                 **kwargs):
        super().__init__()
        self.hook = hook
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_in_c // 2, blur=blur, leaky=leaky, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        ni = up_in_c // 2 + x_in_c
        nf = int((ni if final_div else ni // 2) * nf_factor)
        self.conv1 = custom_conv_layer(ni, nf, leaky=leaky, **kwargs)
        self.conv2 = custom_conv_layer(
            nf, nf, leaky=leaky, self_attention=self_attention, **kwargs)
        self.relu = relu(leaky=leaky)

    def forward(self, up_in):
        # `s` is the encoder activation captured by the forward hook on the
        # matching downsampling stage (the U-Net skip connection).
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv2(self.conv1(cat_x))


class DynamicUnetDeep(SequentialEx):
    'Create a U-Net from a given architecture.'

    def __init__(self,
                 encoder,
                 n_classes,
                 blur=False,
                 blur_final=True,
                 self_attention=False,
                 y_range=None,
                 last_cross=True,
                 bottle=False,
                 norm_type=NormType.Batch,
                 nf_factor=1.0,
                 **kwargs):
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        # Trace the encoder once with a dummy batch to find the layers where
        # the activation size changes; hooks at those layers feed the skips.
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()
        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv]

        # `blur_final` is accepted for API parity but unused here: `blur` is
        # forwarded to every up-block as-is.
        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            sa = self_attention and (i == len(sfs_idxs) - 3)
            unet_block = UnetBlockDeep(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                nf_factor=nf_factor,
                **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(
                res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(
                ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, 'sfs'):
            self.sfs.remove()


# ------------------------------------------------------


class UnetBlockWide(nn.Module):
    'A quasi-UNet block, using `PixelShuffle_ICNR` upsampling.'

    def __init__(self,
                 up_in_c,
                 x_in_c,
                 n_out,
                 hook,
                 final_div=True,
                 blur=False,
                 leaky=None,
                 self_attention=False,
                 **kwargs):
        super().__init__()
        # `final_div` is unused here; it is kept for signature parity with
        # UnetBlockDeep.
        self.hook = hook
        up_out = x_out = n_out // 2
        self.shuf = CustomPixelShuffle_ICNR(
            up_in_c, up_out, blur=blur, leaky=leaky, **kwargs)
        self.bn = nn.BatchNorm2d(x_in_c)
        ni = up_out + x_in_c
        self.conv = custom_conv_layer(
            ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs)
        self.relu = relu(leaky=leaky)

    def forward(self, up_in):
        # `s` is the encoder activation captured by the forward hook on the
        # matching downsampling stage (the U-Net skip connection).
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv(cat_x)


class DynamicUnetWide(SequentialEx):
    'Create a U-Net from a given architecture.'

    def __init__(self,
                 encoder,
                 n_classes,
                 blur=False,
                 blur_final=True,
                 self_attention=False,
                 y_range=None,
                 last_cross=True,
                 bottle=False,
                 norm_type=NormType.Batch,
                 nf_factor=1,
                 **kwargs):
        # Decoder width multiplier: each up-block outputs nf channels
        # (nf // 2 in the final block).
        nf = 512 * nf_factor
        extra_bn = norm_type == NormType.Spectral
        imsize = (256, 256)
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()
        ni = sfs_szs[-1][1]
        middle_conv = nn.Sequential(
            custom_conv_layer(
                ni, ni * 2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
            custom_conv_layer(
                ni * 2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
        ).eval()
        x = middle_conv(x)
        layers = [encoder, nn.BatchNorm2d(ni), nn.ReLU(), middle_conv]

        # As in DynamicUnetDeep, `blur_final` is accepted but unused; `blur`
        # is forwarded to every up-block as-is.
        for i, idx in enumerate(sfs_idxs):
            not_final = i != len(sfs_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
            sa = self_attention and (i == len(sfs_idxs) - 3)
            n_out = nf if not_final else nf // 2
            unet_block = UnetBlockWide(
                up_in_c,
                x_in_c,
                n_out,
                self.sfs[i],
                final_div=not_final,
                blur=blur,
                self_attention=sa,
                norm_type=norm_type,
                extra_bn=extra_bn,
                **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sfs_szs[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, **kwargs))
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(
                res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
        layers += [
            custom_conv_layer(
                ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)
        ]
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, 'sfs'):
            self.sfs.remove()
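
A minimal construction sketch of DynamicUnetWide for orientation, not part of the diff itself; the hyperparameters mirror the colorization pipeline added below, and pretrained=False merely avoids a weight download in this illustration:

    import torch
    from torchvision import models

    encoder = torch.nn.Sequential(
        *list(models.resnet101(pretrained=False).children())[:8])
    net = DynamicUnetWide(
        encoder, n_classes=3, blur=True, self_attention=True,
        y_range=(-3.0, 3.0), norm_type=NormType.Spectral, nf_factor=2).eval()
    with torch.no_grad():
        out = net(torch.rand(1, 3, 256, 256))  # -> torch.Size([1, 3, 256, 256])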

@@ -0,0 +1,348 @@
import functools
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

NormType = Enum('NormType',
                'Batch BatchZero Weight Spectral Group Instance SpectralGN')


def is_listy(x):
    return isinstance(x, (tuple, list))


class Hook():
    'Create a hook on `m` with `hook_func`.'

    def __init__(self, m, hook_func, is_forward=True, detach=True):
        self.hook_func, self.detach, self.stored = hook_func, detach, None
        f = m.register_forward_hook if is_forward else m.register_backward_hook
        self.hook = f(self.hook_fn)
        self.removed = False

    def hook_fn(self, module, input, output):
        'Applies `hook_func` to `module`, `input`, `output`.'
        if self.detach:
            input = (o.detach()
                     for o in input) if is_listy(input) else input.detach()
            output = (
                o.detach()
                for o in output) if is_listy(output) else output.detach()
        self.stored = self.hook_func(module, input, output)

    def remove(self):
        'Remove the hook from the model.'
        if not self.removed:
            self.hook.remove()
            self.removed = True

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()


class Hooks():
    'Create several hooks on the modules in `ms` with `hook_func`.'

    def __init__(self, ms, hook_func, is_forward=True, detach=True):
        self.hooks = [Hook(m, hook_func, is_forward, detach) for m in ms]

    def __getitem__(self, i):
        return self.hooks[i]

    def __len__(self):
        return len(self.hooks)

    def __iter__(self):
        return iter(self.hooks)

    @property
    def stored(self):
        return [o.stored for o in self]

    def remove(self):
        'Remove the hooks from the model.'
        for h in self.hooks:
            h.remove()

    def __enter__(self, *args):
        return self

    def __exit__(self, *args):
        self.remove()


def _hook_inner(m, i, o):
    return o if isinstance(o, torch.Tensor) else o if is_listy(o) else list(o)


def hook_outputs(modules, detach=True, grad=False):
    'Return `Hooks` that store activations of all `modules` in `self.stored`'
    return Hooks(modules, _hook_inner, detach=detach, is_forward=not grad)


def one_param(m):
    'Return the first parameter of `m`.'
    return next(m.parameters())


def dummy_batch(m, size=(64, 64)):
    'Create a dummy batch to go through `m` with `size`.'
    ch_in = in_channels(m)
    return one_param(m).new(1, ch_in,
                            *size).requires_grad_(False).uniform_(-1., 1.)


def dummy_eval(m, size=(64, 64)):
    'Pass a `dummy_batch` in evaluation mode in `m` with `size`.'
    return m.eval()(dummy_batch(m, size))


def model_sizes(m, size=(64, 64)):
    'Pass a dummy input through the model `m` to get the various sizes of activations.'
    with hook_outputs(m) as hooks:
        dummy_eval(m, size)
        return [o.stored.shape for o in hooks]
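

# Illustrative example:
#   m = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1),
#                     nn.Conv2d(8, 16, 3, stride=2, padding=1))
#   model_sizes(m, size=(64, 64))
#   -> [torch.Size([1, 8, 32, 32]), torch.Size([1, 16, 16, 16])]
# The dummy batch infers its channel count from the first weight layer (3 here).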


class PrePostInitMeta(type):
    'A metaclass that calls optional `__pre_init__` and `__post_init__` methods'

    def __new__(cls, name, bases, dct):
        x = super().__new__(cls, name, bases, dct)
        old_init = x.__init__

        def _pass(self):
            pass

        @functools.wraps(old_init)
        def _init(self, *args, **kwargs):
            self.__pre_init__()
            old_init(self, *args, **kwargs)
            self.__post_init__()

        x.__init__ = _init
        if not hasattr(x, '__pre_init__'):
            x.__pre_init__ = _pass
        if not hasattr(x, '__post_init__'):
            x.__post_init__ = _pass
        return x


class Module(nn.Module, metaclass=PrePostInitMeta):
    'Same as `nn.Module`, but no need for subclasses to call `super().__init__`'

    def __pre_init__(self):
        super().__init__()

    def __init__(self):
        pass
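

# Example: subclasses of `Module` can register submodules directly, with no
# explicit super().__init__() call (the metaclass runs __pre_init__ first):
#
#     class Tiny(Module):
#         def __init__(self):
#             self.lin = nn.Linear(2, 2)
#
#         def forward(self, x):
#             return self.lin(x)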


def children(m):
    'Get children of `m`.'
    return list(m.children())


def num_children(m):
    'Get number of children modules in `m`.'
    return len(children(m))


class ParameterModule(Module):
    'Register a lone parameter `p` in a module.'

    def __init__(self, p):
        self.val = p

    def forward(self, x):
        return x


def children_and_parameters(m: nn.Module):
    'Return the children of `m` and its direct parameters not registered in modules.'
    children = list(m.children())
    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],
                     [])
    # Wrap any parameter owned by `m` directly (not by a child module) in a
    # ParameterModule so that flatten_model still sees it.
    for p in m.parameters():
        if id(p) not in children_p:
            children.append(ParameterModule(p))
    return children


def flatten_model(m):
    if num_children(m):
        mapped = map(flatten_model, children_and_parameters(m))
        return sum(mapped, [])
    else:
        return [m]


def in_channels(m):
    'Return the shape of the first weight layer in `m`.'
    for layer in flatten_model(m):
        if hasattr(layer, 'weight'):
            return layer.weight.shape[1]
    raise Exception('No weight layer')


def relu(inplace: bool = False, leaky: float = None):
    'Return a relu activation, maybe `leaky` and `inplace`.'
    return nn.LeakyReLU(
        inplace=inplace,
        negative_slope=leaky) if leaky is not None else nn.ReLU(
            inplace=inplace)


def conv_layer(ni,
               nf,
               ks=3,
               stride=1,
               padding=None,
               bias=None,
               is_1d=False,
               norm_type=NormType.Batch,
               use_activ=True,
               leaky=None,
               transpose=False,
               init=nn.init.kaiming_normal_,
               self_attention=False):
    'Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers.'
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero)
    if bias is None:
        bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = conv_func(
        ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding)
    if norm_type == NormType.Weight:
        conv = weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(relu(True, leaky=leaky))
    if bn:
        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)


def res_block(nf,
              dense=False,
              norm_type=NormType.Batch,
              bottle=False,
              **conv_kwargs):
    'Resnet block of `nf` features. `conv_kwargs` are passed to `conv_layer`.'
    norm2 = norm_type
    if not dense and (norm_type == NormType.Batch):
        norm2 = NormType.BatchZero
    # Note: unlike fastai, this port does not zero-init the BN gamma for
    # BatchZero, so the Batch/BatchZero distinction only matters at init time;
    # with pretrained weights it has no effect.
    nf_inner = nf // 2 if bottle else nf
    return SequentialEx(
        conv_layer(nf, nf_inner, norm_type=norm_type, **conv_kwargs),
        conv_layer(nf_inner, nf, norm_type=norm2, **conv_kwargs),
        MergeLayer(dense))


def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False):
    'Create and initialize a `nn.Conv1d` layer with spectral normalization.'
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias:
        conv.bias.data.zero_()
    return spectral_norm(conv)


class SelfAttention(Module):
    'Self attention layer for nd.'

    def __init__(self, n_channels):
        self.query = conv1d(n_channels, n_channels // 8)
        self.key = conv1d(n_channels, n_channels // 8)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):
        'Notation from https://arxiv.org/pdf/1805.08318.pdf'
        size = x.size()
        x = x.view(*size[:2], -1)
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()
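

# Shape walk-through (illustrative), following the SAGAN notation: x is viewed
# as (B, C, N) with N = H*W; query/key give f, g of shape (B, C//8, N); beta =
# softmax(f^T g) is (B, N, N); o = gamma * bmm(h, beta) + x is (B, C, N) and is
# reshaped back to (B, C, H, W). gamma starts at 0, so the block is initially
# the identity.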


def sigmoid_range(x, low, high):
    'Sigmoid function with range `(low, high)`'
    return torch.sigmoid(x) * (high - low) + low
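

# e.g. sigmoid_range(torch.tensor(0.), -3.0, 3.0) == 0., and all outputs lie in
# (-3, 3) -- the y_range the colorization models above are built with.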


class SigmoidRange(Module):
    'Sigmoid module with range `(low, high)`'

    def __init__(self, low, high):
        self.low, self.high = low, high

    def forward(self, x):
        return sigmoid_range(x, self.low, self.high)


class SequentialEx(Module):
    'Like `nn.Sequential`, but with ModuleList semantics, and can access module input'

    def __init__(self, *layers):
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        res = x
        for layer in self.layers:
            # Expose the original input on the running tensor so layers such
            # as MergeLayer can reach it via `res.orig`.
            res.orig = x
            nres = layer(res)
            # Clear the attribute to avoid keeping `x` alive through `res`.
            res.orig = None
            res = nres
        return res

    def __getitem__(self, i):
        return self.layers[i]

    def append(self, layer):
        return self.layers.append(layer)

    def extend(self, layer):
        return self.layers.extend(layer)

    def insert(self, i, layer):
        return self.layers.insert(i, layer)


class MergeLayer(Module):
    'Merge a shortcut with the result of the module by adding them or concatenating them if `dense=True`.'

    def __init__(self, dense: bool = False):
        self.dense = dense

    def forward(self, x):
        # dense=True concatenates along channels -> (B, C1 + C2, H, W);
        # otherwise the shortcut is added, so shapes must match exactly.
        return torch.cat([x, x.orig], dim=1) if self.dense else (x + x.orig)


class PixelShuffle_ICNR(Module):
    'Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, and `weight_norm`.'

    def __init__(self,
                 ni: int,
                 nf: int = None,
                 scale: int = 2,
                 blur: bool = False,
                 norm_type=NormType.Weight,
                 leaky: float = None):
        nf = ni if nf is None else nf
        self.conv = conv_layer(
            ni, nf * (scale**2), ks=1, norm_type=norm_type, use_activ=False)
        self.shuf = nn.PixelShuffle(scale)
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = relu(True, leaky=leaky)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        # As in CustomPixelShuffle_ICNR, `self.blur` is always truthy, so the
        # blur is applied regardless of the `blur` argument.
        return self.blur(self.pad(x)) if self.blur else x

@@ -70,6 +70,7 @@ TASK_OUTPUTS = {
    Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
    Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
    Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
    Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
    Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],

@@ -68,6 +68,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.text_to_image_synthesis:
    (Pipelines.text_to_image_synthesis,
     'damo/cv_imagen_text-to-image-synthesis_tiny'),
    Tasks.image_colorization: (Pipelines.image_colorization,
                               'damo/cv_unet_image-colorization'),
    Tasks.style_transfer: (Pipelines.style_transfer,
                           'damo/cv_aams_style-transfer_damo'),
    Tasks.face_image_generation: (Pipelines.face_image_generation,

@@ -6,6 +6,7 @@ try:
    from .action_recognition_pipeline import ActionRecognitionPipeline
    from .animal_recog_pipeline import AnimalRecogPipeline
    from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
    from .image_colorization_pipeline import ImageColorizationPipeline
    from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
    from .face_image_generation_pipeline import FaceImageGenerationPipeline
except ModuleNotFoundError as e:

@@ -0,0 +1,132 @@
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch
from torchvision import models, transforms

from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_colorization import unet
from modelscope.models.cv.image_colorization.utils import NormType
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_colorization, module_name=Pipelines.image_colorization)
class ImageColorizationPipeline(Pipeline):

    def __init__(self, model: str):
        """Use `model` to create an image colorization pipeline for prediction.

        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model)
        # Fall back to CPU so the pipeline still runs on machines without a
        # GPU; the CPU branch of `self.size` below already anticipates this.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.cut = 8
        self.size = 1024 if self.device == 'cpu' else 512
        self.orig_img = None
        # 'stable' selects the wide ResNet101 variant; the deep ResNet34
        # variant below is currently unreachable since model_type is fixed.
        self.model_type = 'stable'
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.denorm = transforms.Normalize(
            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
            std=[1 / 0.229, 1 / 0.224, 1 / 0.225])

        if self.model_type == 'stable':
            body = models.resnet101(pretrained=True)
            body = torch.nn.Sequential(*list(body.children())[:self.cut])
            self.model = unet.DynamicUnetWide(
                body,
                n_classes=3,
                blur=True,
                blur_final=True,
                self_attention=True,
                y_range=(-3.0, 3.0),
                norm_type=NormType.Spectral,
                last_cross=True,
                bottle=False,
                nf_factor=2,
            ).to(self.device)
        else:
            body = models.resnet34(pretrained=True)
            body = torch.nn.Sequential(*list(body.children())[:self.cut])
            self.model = unet.DynamicUnetDeep(
                body,
                n_classes=3,
                blur=True,
                blur_final=True,
                self_attention=True,
                y_range=(-3.0, 3.0),
                norm_type=NormType.Spectral,
                last_cross=True,
                bottle=False,
                nf_factor=1.5,
            ).to(self.device)

        model_path = f'{model}/{ModelFile.TORCH_MODEL_FILE}'
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device)['model'],
            strict=True)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        # Reset per call so a small image does not reuse the previous call's
        # high-resolution original in forward().
        self.orig_img = None
        if isinstance(input, str):
            img = load_image(input).convert('LA').convert('RGB')
        elif isinstance(input, PIL.Image.Image):
            img = input.convert('LA').convert('RGB')
        elif isinstance(input, np.ndarray):
            if len(input.shape) == 2:
                input = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            img = input[:, :, ::-1]  # BGR -> RGB order
            img = PIL.Image.fromarray(img).convert('LA').convert('RGB')
        else:
            raise TypeError(f'input should be either str, PIL.Image or'
                            f' np.ndarray, but got {type(input)}')
        self.width, self.height = img.size
        if self.width * self.height > self.size * self.size:
            # Keep the full-resolution grayscale original; the network runs on
            # a downscaled copy and the chroma is pasted back in forward().
            self.orig_img = img.copy()
            img = img.resize((self.size, self.size),
                             resample=PIL.Image.BILINEAR)
        img = self.norm(img).unsqueeze(0).to(self.device)
        result = {'img': img}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        self.model.eval()
        with torch.no_grad():
            out = self.model(input['img'])[0]
        out = self.denorm(out)
        out = out.float().clamp(min=0, max=1)
        # (C, H, W) RGB tensor -> (H, W, C) BGR uint8 array.
        out_img = (out.permute(1, 2, 0).flip(2).cpu().numpy() * 255).astype(
            np.uint8)

        if self.orig_img is not None:
            # Transfer only the predicted chroma (U, V) onto the luma (Y) of
            # the full-resolution original, keeping detail at native size.
            color_np = cv2.resize(out_img, self.orig_img.size)
            orig_np = np.asarray(self.orig_img)
            color_yuv = cv2.cvtColor(color_np, cv2.COLOR_BGR2YUV)
            # The original is grayscale (R == G == B), so the BGR/RGB channel
            # order does not matter for this conversion.
            orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_BGR2YUV)
            hires = np.copy(orig_yuv)
            hires[:, :, 1:3] = color_yuv[:, :, 1:3]
            out_img = cv2.cvtColor(hires, cv2.COLOR_YUV2BGR)

        return {OutputKeys.OUTPUT_IMG: out_img.astype(np.uint8)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        return inputs

@@ -27,6 +27,7 @@ class CVTasks(object):
    ocr_detection = 'ocr-detection'
    action_recognition = 'action-recognition'
    video_embedding = 'video-embedding'
    image_colorization = 'image-colorization'
    face_image_generation = 'face-image-generation'
    image_super_resolution = 'image-super-resolution'
    style_transfer = 'style-transfer'

@@ -0,0 +1,42 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.msdatasets import MsDataset
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class ImageColorizationTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_unet_image-colorization'
        self.test_image = 'data/test/images/marilyn_monroe_4.jpg'

    def pipeline_inference(self, pipeline: Pipeline, test_image: str):
        result = pipeline(test_image)
        if result is not None:
            cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
            print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub(self):
        image_colorization = pipeline(
            Tasks.image_colorization, model=self.model_id)
        self.pipeline_inference(image_colorization, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        image_colorization = pipeline(Tasks.image_colorization)
        self.pipeline_inference(image_colorization, self.test_image)


if __name__ == '__main__':
    unittest.main()