Add a StyleGAN2-based face image generation algorithm
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9426478
master
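Reviewer note: a minimal usage sketch of the pipeline this change registers, mirroring the unit test at the bottom of the diff. It assumes the `damo/cv_gan_face-image-generation` checkpoint referenced by that test is available on the ModelScope hub.

```python
import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The pipeline takes an integer random seed and returns a generated face image.
face_generation = pipeline(
    Tasks.image_generation, model='damo/cv_gan_face-image-generation')
result = face_generation(10)
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])  # BGR uint8, 1024x1024
```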
| @@ -49,6 +49,7 @@ class Pipelines(object): | |||
| action_recognition = 'TAdaConv_action-recognition' | |||
| animal_recognation = 'resnet101-animal_recog' | |||
| cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' | |||
| face_image_generation = 'gan-face-image-generation' | |||
| style_transfer = 'AAMS-style-transfer' | |||
| # nlp tasks | |||
| @@ -0,0 +1,2 @@ | |||
| from .fused_act import FusedLeakyReLU, fused_leaky_relu | |||
| from .upfirdn2d import upfirdn2d | |||
| @@ -0,0 +1,229 @@ | |||
| import contextlib | |||
| import warnings | |||
| import torch | |||
| from torch import autograd | |||
| from torch.nn import functional as F | |||
| enabled = True | |||
| weight_gradients_disabled = False | |||
| @contextlib.contextmanager | |||
| def no_weight_gradients(): | |||
| global weight_gradients_disabled | |||
| old = weight_gradients_disabled | |||
| weight_gradients_disabled = True | |||
| yield | |||
| weight_gradients_disabled = old | |||
| def conv2d(input, | |||
| weight, | |||
| bias=None, | |||
| stride=1, | |||
| padding=0, | |||
| dilation=1, | |||
| groups=1): | |||
| if could_use_op(input): | |||
| return conv2d_gradfix( | |||
| transpose=False, | |||
| weight_shape=weight.shape, | |||
| stride=stride, | |||
| padding=padding, | |||
| output_padding=0, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| ).apply(input, weight, bias) | |||
| return F.conv2d( | |||
| input=input, | |||
| weight=weight, | |||
| bias=bias, | |||
| stride=stride, | |||
| padding=padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| ) | |||
| def conv_transpose2d( | |||
| input, | |||
| weight, | |||
| bias=None, | |||
| stride=1, | |||
| padding=0, | |||
| output_padding=0, | |||
| groups=1, | |||
| dilation=1, | |||
| ): | |||
| if could_use_op(input): | |||
| return conv2d_gradfix( | |||
| transpose=True, | |||
| weight_shape=weight.shape, | |||
| stride=stride, | |||
| padding=padding, | |||
| output_padding=output_padding, | |||
| groups=groups, | |||
| dilation=dilation, | |||
| ).apply(input, weight, bias) | |||
| return F.conv_transpose2d( | |||
| input=input, | |||
| weight=weight, | |||
| bias=bias, | |||
| stride=stride, | |||
| padding=padding, | |||
| output_padding=output_padding, | |||
| dilation=dilation, | |||
| groups=groups, | |||
| ) | |||
| def could_use_op(input): | |||
| if (not enabled) or (not torch.backends.cudnn.enabled): | |||
| return False | |||
| if input.device.type != 'cuda': | |||
| return False | |||
| if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']): | |||
| return True | |||
| warnings.warn( | |||
| f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().' | |||
| ) | |||
| return False | |||
| def ensure_tuple(xs, ndim): | |||
| xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim | |||
| return xs | |||
| conv2d_gradfix_cache = dict() | |||
| def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, | |||
| dilation, groups): | |||
| ndim = 2 | |||
| weight_shape = tuple(weight_shape) | |||
| stride = ensure_tuple(stride, ndim) | |||
| padding = ensure_tuple(padding, ndim) | |||
| output_padding = ensure_tuple(output_padding, ndim) | |||
| dilation = ensure_tuple(dilation, ndim) | |||
| key = (transpose, weight_shape, stride, padding, output_padding, dilation, | |||
| groups) | |||
| if key in conv2d_gradfix_cache: | |||
| return conv2d_gradfix_cache[key] | |||
| common_kwargs = dict( | |||
| stride=stride, padding=padding, dilation=dilation, groups=groups) | |||
| def calc_output_padding(input_shape, output_shape): | |||
| if transpose: | |||
| return [0, 0] | |||
| return [ | |||
| input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] | |||
| - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1) | |||
| for i in range(ndim) | |||
| ] | |||
| class Conv2d(autograd.Function): | |||
| @staticmethod | |||
| def forward(ctx, input, weight, bias): | |||
| if not transpose: | |||
| out = F.conv2d( | |||
| input=input, weight=weight, bias=bias, **common_kwargs) | |||
| else: | |||
| out = F.conv_transpose2d( | |||
| input=input, | |||
| weight=weight, | |||
| bias=bias, | |||
| output_padding=output_padding, | |||
| **common_kwargs, | |||
| ) | |||
| ctx.save_for_backward(input, weight) | |||
| return out | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| input, weight = ctx.saved_tensors | |||
| grad_input, grad_weight, grad_bias = None, None, None | |||
| if ctx.needs_input_grad[0]: | |||
| p = calc_output_padding( | |||
| input_shape=input.shape, output_shape=grad_output.shape) | |||
| grad_input = conv2d_gradfix( | |||
| transpose=(not transpose), | |||
| weight_shape=weight_shape, | |||
| output_padding=p, | |||
| **common_kwargs, | |||
| ).apply(grad_output, weight, None) | |||
| if ctx.needs_input_grad[1] and not weight_gradients_disabled: | |||
| grad_weight = Conv2dGradWeight.apply(grad_output, input) | |||
| if ctx.needs_input_grad[2]: | |||
| grad_bias = grad_output.sum((0, 2, 3)) | |||
| return grad_input, grad_weight, grad_bias | |||
| class Conv2dGradWeight(autograd.Function): | |||
| @staticmethod | |||
| def forward(ctx, grad_output, input): | |||
| op = torch._C._jit_get_operation( | |||
| 'aten::cudnn_convolution_backward_weight' if not transpose else | |||
| 'aten::cudnn_convolution_transpose_backward_weight') | |||
| flags = [ | |||
| torch.backends.cudnn.benchmark, | |||
| torch.backends.cudnn.deterministic, | |||
| torch.backends.cudnn.allow_tf32, | |||
| ] | |||
| grad_weight = op( | |||
| weight_shape, | |||
| grad_output, | |||
| input, | |||
| padding, | |||
| stride, | |||
| dilation, | |||
| groups, | |||
| *flags, | |||
| ) | |||
| ctx.save_for_backward(grad_output, input) | |||
| return grad_weight | |||
| @staticmethod | |||
| def backward(ctx, grad_grad_weight): | |||
| grad_output, input = ctx.saved_tensors | |||
| grad_grad_output, grad_grad_input = None, None | |||
| if ctx.needs_input_grad[0]: | |||
| grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) | |||
| if ctx.needs_input_grad[1]: | |||
| p = calc_output_padding( | |||
| input_shape=input.shape, output_shape=grad_output.shape) | |||
| grad_grad_input = conv2d_gradfix( | |||
| transpose=(not transpose), | |||
| weight_shape=weight_shape, | |||
| output_padding=p, | |||
| **common_kwargs, | |||
| ).apply(grad_output, grad_grad_weight, None) | |||
| return grad_grad_output, grad_grad_input | |||
| conv2d_gradfix_cache[key] = Conv2d | |||
| return Conv2d | |||
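Reviewer note: a small sanity check of how `conv2d_gradfix` is meant to be used. On unsupported setups (CPU tensors, or PyTorch versions other than 1.7.x/1.8.x) `could_use_op()` returns False and the wrapper falls back to `torch.nn.functional.conv2d`; `no_weight_gradients()` only suppresses weight gradients on the custom CUDA autograd path. The import path below assumes this file lands at `modelscope/models/cv/face_generation/op/conv2d_gradfix.py`.

```python
import torch
from torch.nn import functional as F

# Assumed final location of the module added in this diff.
from modelscope.models.cv.face_generation.op import conv2d_gradfix

x = torch.randn(2, 3, 16, 16, requires_grad=True)
w = torch.randn(8, 3, 3, 3, requires_grad=True)

# On CPU the fallback path is taken, so the result matches F.conv2d exactly.
out = conv2d_gradfix.conv2d(x, w, padding=1)
assert torch.allclose(out, F.conv2d(x, w, padding=1))

# StyleGAN2-style usage: skip weight gradients while computing a regularizer.
# This only has an effect on the CUDA custom-op path; it is a no-op in the fallback.
with conv2d_gradfix.no_weight_gradients():
    loss = conv2d_gradfix.conv2d(x, w, padding=1).square().mean()
loss.backward()
```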
| @@ -0,0 +1,113 @@ | |||
| import os | |||
| import torch | |||
| from torch import nn | |||
| from torch.autograd import Function | |||
| from torch.nn import functional as F | |||
| def_lib = False | |||
| class FusedLeakyReLUFunctionBackward(Function): | |||
| @staticmethod | |||
| def forward(ctx, grad_output, out, bias, negative_slope, scale): | |||
| ctx.save_for_backward(out) | |||
| ctx.negative_slope = negative_slope | |||
| ctx.scale = scale | |||
| empty = grad_output.new_empty(0) | |||
| grad_input = fused.fused_bias_act(grad_output.contiguous(), empty, out, | |||
| 3, 1, negative_slope, scale) | |||
| dim = [0] | |||
| if grad_input.ndim > 2: | |||
| dim += list(range(2, grad_input.ndim)) | |||
| if bias: | |||
| grad_bias = grad_input.sum(dim).detach() | |||
| else: | |||
| grad_bias = empty | |||
| return grad_input, grad_bias | |||
| @staticmethod | |||
| def backward(ctx, gradgrad_input, gradgrad_bias): | |||
| out, = ctx.saved_tensors | |||
| gradgrad_out = fused.fused_bias_act( | |||
| gradgrad_input.contiguous(), | |||
| gradgrad_bias, | |||
| out, | |||
| 3, | |||
| 1, | |||
| ctx.negative_slope, | |||
| ctx.scale, | |||
| ) | |||
| return gradgrad_out, None, None, None, None | |||
| class FusedLeakyReLUFunction(Function): | |||
| @staticmethod | |||
| def forward(ctx, input, bias, negative_slope, scale): | |||
| empty = input.new_empty(0) | |||
| ctx.bias = bias is not None | |||
| if bias is None: | |||
| bias = empty | |||
| out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, | |||
| scale) | |||
| ctx.save_for_backward(out) | |||
| ctx.negative_slope = negative_slope | |||
| ctx.scale = scale | |||
| return out | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| out, = ctx.saved_tensors | |||
| grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( | |||
| grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale) | |||
| if not ctx.bias: | |||
| grad_bias = None | |||
| return grad_input, grad_bias, None, None | |||
| class FusedLeakyReLU(nn.Module): | |||
| def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5): | |||
| super().__init__() | |||
| if bias: | |||
| self.bias = nn.Parameter(torch.zeros(channel)) | |||
| else: | |||
| self.bias = None | |||
| self.negative_slope = negative_slope | |||
| self.scale = scale | |||
| def forward(self, input): | |||
| return fused_leaky_relu(input, self.bias, self.negative_slope, | |||
| self.scale) | |||
| def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5): | |||
| if not def_lib: | |||
| if bias is not None: | |||
| rest_dim = [1] * (input.ndim - bias.ndim - 1) | |||
| return (F.leaky_relu( | |||
| input + bias.view(1, bias.shape[0], *rest_dim), | |||
| negative_slope=negative_slope) * scale) | |||
| else: | |||
| return F.leaky_relu(input, negative_slope=negative_slope) * scale | |||
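Reviewer note: the pure-PyTorch fallback above amounts to "add the per-channel bias, apply a leaky ReLU, rescale by sqrt(2)", the rescaling keeping activation variance roughly constant. A tiny reference check, assuming the fallback path (`def_lib` is False) is in effect:

```python
import torch
from torch.nn import functional as F

# Assumed final location of the op package added in this diff.
from modelscope.models.cv.face_generation.op import fused_leaky_relu


def fused_leaky_relu_reference(x, bias, negative_slope=0.2, scale=2**0.5):
    # Broadcast the per-channel bias over (N, C, H, W), apply leaky ReLU, rescale.
    return F.leaky_relu(x + bias.view(1, -1, 1, 1), negative_slope) * scale


x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)
assert torch.allclose(fused_leaky_relu(x, b), fused_leaky_relu_reference(x, b))
```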
| @@ -0,0 +1,197 @@ | |||
| import os | |||
| from collections import abc | |||
| import torch | |||
| from torch.autograd import Function | |||
| from torch.nn import functional as F | |||
| def_lib = False | |||
| class UpFirDn2dBackward(Function): | |||
| @staticmethod | |||
| def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, | |||
| in_size, out_size): | |||
| up_x, up_y = up | |||
| down_x, down_y = down | |||
| g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad | |||
| grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) | |||
| grad_input = upfirdn2d_op.upfirdn2d( | |||
| grad_output, | |||
| grad_kernel, | |||
| down_x, | |||
| down_y, | |||
| up_x, | |||
| up_y, | |||
| g_pad_x0, | |||
| g_pad_x1, | |||
| g_pad_y0, | |||
| g_pad_y1, | |||
| ) | |||
| grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], | |||
| in_size[3]) | |||
| ctx.save_for_backward(kernel) | |||
| pad_x0, pad_x1, pad_y0, pad_y1 = pad | |||
| ctx.up_x = up_x | |||
| ctx.up_y = up_y | |||
| ctx.down_x = down_x | |||
| ctx.down_y = down_y | |||
| ctx.pad_x0 = pad_x0 | |||
| ctx.pad_x1 = pad_x1 | |||
| ctx.pad_y0 = pad_y0 | |||
| ctx.pad_y1 = pad_y1 | |||
| ctx.in_size = in_size | |||
| ctx.out_size = out_size | |||
| return grad_input | |||
| @staticmethod | |||
| def backward(ctx, gradgrad_input): | |||
| kernel, = ctx.saved_tensors | |||
| gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], | |||
| ctx.in_size[3], 1) | |||
| gradgrad_out = upfirdn2d_op.upfirdn2d( | |||
| gradgrad_input, | |||
| kernel, | |||
| ctx.up_x, | |||
| ctx.up_y, | |||
| ctx.down_x, | |||
| ctx.down_y, | |||
| ctx.pad_x0, | |||
| ctx.pad_x1, | |||
| ctx.pad_y0, | |||
| ctx.pad_y1, | |||
| ) | |||
| # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) | |||
| gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], | |||
| ctx.out_size[0], ctx.out_size[1]) | |||
| return gradgrad_out, None, None, None, None, None, None, None, None | |||
| class UpFirDn2d(Function): | |||
| @staticmethod | |||
| def forward(ctx, input, kernel, up, down, pad): | |||
| up_x, up_y = up | |||
| down_x, down_y = down | |||
| pad_x0, pad_x1, pad_y0, pad_y1 = pad | |||
| kernel_h, kernel_w = kernel.shape | |||
| batch, channel, in_h, in_w = input.shape | |||
| ctx.in_size = input.shape | |||
| input = input.reshape(-1, in_h, in_w, 1) | |||
| ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) | |||
| out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y | |||
| out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x | |||
| ctx.out_size = (out_h, out_w) | |||
| ctx.up = (up_x, up_y) | |||
| ctx.down = (down_x, down_y) | |||
| ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) | |||
| g_pad_x0 = kernel_w - pad_x0 - 1 | |||
| g_pad_y0 = kernel_h - pad_y0 - 1 | |||
| g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 | |||
| g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 | |||
| ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) | |||
| out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, | |||
| pad_x0, pad_x1, pad_y0, pad_y1) | |||
| # out = out.view(major, out_h, out_w, minor) | |||
| out = out.view(-1, channel, out_h, out_w) | |||
| return out | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| kernel, grad_kernel = ctx.saved_tensors | |||
| grad_input = None | |||
| if ctx.needs_input_grad[0]: | |||
| grad_input = UpFirDn2dBackward.apply( | |||
| grad_output, | |||
| kernel, | |||
| grad_kernel, | |||
| ctx.up, | |||
| ctx.down, | |||
| ctx.pad, | |||
| ctx.g_pad, | |||
| ctx.in_size, | |||
| ctx.out_size, | |||
| ) | |||
| return grad_input, None, None, None, None | |||
| def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): | |||
| if not isinstance(up, abc.Iterable): | |||
| up = (up, up) | |||
| if not isinstance(down, abc.Iterable): | |||
| down = (down, down) | |||
| if len(pad) == 2: | |||
| pad = (pad[0], pad[1], pad[0], pad[1]) | |||
| if not def_lib: | |||
| out = upfirdn2d_native(input, kernel, *up, *down, *pad) | |||
| return out | |||
| def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, | |||
| pad_y0, pad_y1): | |||
| _, channel, in_h, in_w = input.shape | |||
| input = input.reshape(-1, in_h, in_w, 1) | |||
| _, in_h, in_w, minor = input.shape | |||
| kernel_h, kernel_w = kernel.shape | |||
| out = input.view(-1, in_h, 1, in_w, 1, minor) | |||
| out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) | |||
| out = out.view(-1, in_h * up_y, in_w * up_x, minor) | |||
| out = F.pad( | |||
| out, | |||
| [0, 0, | |||
| max(pad_x0, 0), | |||
| max(pad_x1, 0), | |||
| max(pad_y0, 0), | |||
| max(pad_y1, 0)]) | |||
| out = out[:, | |||
| max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), | |||
| max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0)] | |||
| out = out.permute(0, 3, 1, 2) | |||
| out = out.reshape( | |||
| [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) | |||
| w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) | |||
| out = F.conv2d(out, w) | |||
| out = out.reshape( | |||
| -1, | |||
| minor, | |||
| in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, | |||
| in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, | |||
| ) | |||
| out = out.permute(0, 2, 3, 1) | |||
| out = out[:, ::down_y, ::down_x, :] | |||
| out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y | |||
| out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x | |||
| return out.view(-1, channel, out_h, out_w) | |||
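Reviewer note: `upfirdn2d` is "upsample by zero-insertion, pad, convolve with a separable FIR kernel, downsample", and the native implementation above reproduces that with plain tensor ops. A shape check using the same kernel and padding that the `Upsample` module in `stylegan2.py` (further down in this diff) computes:

```python
import torch

# Assumed final location of the op package added in this diff.
from modelscope.models.cv.face_generation.op import upfirdn2d


def make_kernel(k):
    # Same helper as stylegan2.make_kernel: outer product of a 1-D kernel, normalized.
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    return k / k.sum()


x = torch.randn(1, 3, 8, 8)
factor = 2
kernel = make_kernel([1, 3, 3, 1]) * factor**2  # compensate energy lost to zero-insertion

# Padding as computed in stylegan2.Upsample: p = kernel_size - factor.
p = kernel.shape[0] - factor
pad = ((p + 1) // 2 + factor - 1, p // 2)

out = upfirdn2d(x, kernel, up=factor, down=1, pad=pad)
# out_h = (in_h * up + pad0 + pad1 - kernel_h + down) // down = (16 + 2 + 1 - 4 + 1) // 1 = 16
assert out.shape == (1, 3, 16, 16)
```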
| @@ -0,0 +1,731 @@ | |||
| import functools | |||
| import math | |||
| import operator | |||
| import random | |||
| import torch | |||
| from torch import nn | |||
| from torch.autograd import Function | |||
| from torch.nn import functional as F | |||
| from .op import FusedLeakyReLU, conv2d_gradfix, fused_leaky_relu, upfirdn2d | |||
| class PixelNorm(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| def forward(self, input): | |||
| return input * torch.rsqrt( | |||
| torch.mean(input**2, dim=1, keepdim=True) + 1e-8) | |||
| def make_kernel(k): | |||
| k = torch.tensor(k, dtype=torch.float32) | |||
| if k.ndim == 1: | |||
| k = k[None, :] * k[:, None] | |||
| k /= k.sum() | |||
| return k | |||
| class Upsample(nn.Module): | |||
| def __init__(self, kernel, factor=2): | |||
| super().__init__() | |||
| self.factor = factor | |||
| kernel = make_kernel(kernel) * (factor**2) | |||
| self.register_buffer('kernel', kernel) | |||
| p = kernel.shape[0] - factor | |||
| pad0 = (p + 1) // 2 + factor - 1 | |||
| pad1 = p // 2 | |||
| self.pad = (pad0, pad1) | |||
| def forward(self, input): | |||
| out = upfirdn2d( | |||
| input, self.kernel, up=self.factor, down=1, pad=self.pad) | |||
| return out | |||
| class Downsample(nn.Module): | |||
| def __init__(self, kernel, factor=2): | |||
| super().__init__() | |||
| self.factor = factor | |||
| kernel = make_kernel(kernel) | |||
| self.register_buffer('kernel', kernel) | |||
| p = kernel.shape[0] - factor | |||
| pad0 = (p + 1) // 2 | |||
| pad1 = p // 2 | |||
| self.pad = (pad0, pad1) | |||
| def forward(self, input): | |||
| out = upfirdn2d( | |||
| input, self.kernel, up=1, down=self.factor, pad=self.pad) | |||
| return out | |||
| class Blur(nn.Module): | |||
| def __init__(self, kernel, pad, upsample_factor=1): | |||
| super().__init__() | |||
| kernel = make_kernel(kernel) | |||
| if upsample_factor > 1: | |||
| kernel = kernel * (upsample_factor**2) | |||
| self.register_buffer('kernel', kernel) | |||
| self.pad = pad | |||
| def forward(self, input): | |||
| out = upfirdn2d(input, self.kernel, pad=self.pad) | |||
| return out | |||
| class EqualConv2d(nn.Module): | |||
| def __init__(self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| stride=1, | |||
| padding=0, | |||
| bias=True): | |||
| super().__init__() | |||
| self.weight = nn.Parameter( | |||
| torch.randn(out_channel, in_channel, kernel_size, kernel_size)) | |||
| self.scale = 1 / math.sqrt(in_channel * kernel_size**2) | |||
| self.stride = stride | |||
| self.padding = padding | |||
| if bias: | |||
| self.bias = nn.Parameter(torch.zeros(out_channel)) | |||
| else: | |||
| self.bias = None | |||
| def forward(self, input): | |||
| out = conv2d_gradfix.conv2d( | |||
| input, | |||
| self.weight * self.scale, | |||
| bias=self.bias, | |||
| stride=self.stride, | |||
| padding=self.padding, | |||
| ) | |||
| return out | |||
| def __repr__(self): | |||
| return ( | |||
| f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' | |||
| f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' | |||
| ) | |||
| class EqualLinear(nn.Module): | |||
| def __init__(self, | |||
| in_dim, | |||
| out_dim, | |||
| bias=True, | |||
| bias_init=0, | |||
| lr_mul=1, | |||
| activation=None): | |||
| super().__init__() | |||
| self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) | |||
| if bias: | |||
| self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) | |||
| else: | |||
| self.bias = None | |||
| self.activation = activation | |||
| self.scale = (1 / math.sqrt(in_dim)) * lr_mul | |||
| self.lr_mul = lr_mul | |||
| def forward(self, input): | |||
| if self.activation: | |||
| out = F.linear(input, self.weight * self.scale) | |||
| out = fused_leaky_relu(out, self.bias * self.lr_mul) | |||
| else: | |||
| out = F.linear( | |||
| input, self.weight * self.scale, bias=self.bias * self.lr_mul) | |||
| return out | |||
| def __repr__(self): | |||
| return ( | |||
| f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' | |||
| ) | |||
| class ModulatedConv2d(nn.Module): | |||
| def __init__( | |||
| self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| style_dim, | |||
| demodulate=True, | |||
| upsample=False, | |||
| downsample=False, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| fused=True, | |||
| ): | |||
| super().__init__() | |||
| self.eps = 1e-8 | |||
| self.kernel_size = kernel_size | |||
| self.in_channel = in_channel | |||
| self.out_channel = out_channel | |||
| self.upsample = upsample | |||
| self.downsample = downsample | |||
| if upsample: | |||
| factor = 2 | |||
| p = (len(blur_kernel) - factor) - (kernel_size - 1) | |||
| pad0 = (p + 1) // 2 + factor - 1 | |||
| pad1 = p // 2 + 1 | |||
| self.blur = Blur( | |||
| blur_kernel, pad=(pad0, pad1), upsample_factor=factor) | |||
| if downsample: | |||
| factor = 2 | |||
| p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||
| pad0 = (p + 1) // 2 | |||
| pad1 = p // 2 | |||
| self.blur = Blur(blur_kernel, pad=(pad0, pad1)) | |||
| fan_in = in_channel * kernel_size**2 | |||
| self.scale = 1 / math.sqrt(fan_in) | |||
| self.padding = kernel_size // 2 | |||
| self.weight = nn.Parameter( | |||
| torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) | |||
| self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) | |||
| self.demodulate = demodulate | |||
| self.fused = fused | |||
| def __repr__(self): | |||
| return ( | |||
| f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' | |||
| f'upsample={self.upsample}, downsample={self.downsample})') | |||
| def forward(self, input, style): | |||
| batch, in_channel, height, width = input.shape | |||
| if not self.fused: | |||
| weight = self.scale * self.weight.squeeze(0) | |||
| style = self.modulation(style) | |||
| if self.demodulate: | |||
| w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, | |||
| 1) | |||
| dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() | |||
| input = input * style.reshape(batch, in_channel, 1, 1) | |||
| if self.upsample: | |||
| weight = weight.transpose(0, 1) | |||
| out = conv2d_gradfix.conv_transpose2d( | |||
| input, weight, padding=0, stride=2) | |||
| out = self.blur(out) | |||
| elif self.downsample: | |||
| input = self.blur(input) | |||
| out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) | |||
| else: | |||
| out = conv2d_gradfix.conv2d( | |||
| input, weight, padding=self.padding) | |||
| if self.demodulate: | |||
| out = out * dcoefs.view(batch, -1, 1, 1) | |||
| return out | |||
| style = self.modulation(style).view(batch, 1, in_channel, 1, 1) | |||
| weight = self.scale * self.weight * style | |||
| if self.demodulate: | |||
| demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) | |||
| weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) | |||
| weight = weight.view(batch * self.out_channel, in_channel, | |||
| self.kernel_size, self.kernel_size) | |||
| if self.upsample: | |||
| input = input.view(1, batch * in_channel, height, width) | |||
| weight = weight.view(batch, self.out_channel, in_channel, | |||
| self.kernel_size, self.kernel_size) | |||
| weight = weight.transpose(1, 2).reshape(batch * in_channel, | |||
| self.out_channel, | |||
| self.kernel_size, | |||
| self.kernel_size) | |||
| out = conv2d_gradfix.conv_transpose2d( | |||
| input, weight, padding=0, stride=2, groups=batch) | |||
| _, _, height, width = out.shape | |||
| out = out.view(batch, self.out_channel, height, width) | |||
| out = self.blur(out) | |||
| elif self.downsample: | |||
| input = self.blur(input) | |||
| _, _, height, width = input.shape | |||
| input = input.view(1, batch * in_channel, height, width) | |||
| out = conv2d_gradfix.conv2d( | |||
| input, weight, padding=0, stride=2, groups=batch) | |||
| _, _, height, width = out.shape | |||
| out = out.view(batch, self.out_channel, height, width) | |||
| else: | |||
| input = input.view(1, batch * in_channel, height, width) | |||
| out = conv2d_gradfix.conv2d( | |||
| input, weight, padding=self.padding, groups=batch) | |||
| _, _, height, width = out.shape | |||
| out = out.view(batch, self.out_channel, height, width) | |||
| return out | |||
| class NoiseInjection(nn.Module): | |||
| def __init__(self): | |||
| super().__init__() | |||
| self.weight = nn.Parameter(torch.zeros(1)) | |||
| def forward(self, image, noise=None): | |||
| if noise is None: | |||
| batch, _, height, width = image.shape | |||
| noise = image.new_empty(batch, 1, height, width).normal_() | |||
| return image + self.weight * noise | |||
| class ConstantInput(nn.Module): | |||
| def __init__(self, channel, size=4): | |||
| super().__init__() | |||
| self.input = nn.Parameter(torch.randn(1, channel, size, size)) | |||
| def forward(self, input): | |||
| batch = input.shape[0] | |||
| out = self.input.repeat(batch, 1, 1, 1) | |||
| return out | |||
| class StyledConv(nn.Module): | |||
| def __init__( | |||
| self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| style_dim, | |||
| upsample=False, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| demodulate=True, | |||
| ): | |||
| super().__init__() | |||
| self.conv = ModulatedConv2d( | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| style_dim, | |||
| upsample=upsample, | |||
| blur_kernel=blur_kernel, | |||
| demodulate=demodulate, | |||
| ) | |||
| self.noise = NoiseInjection() | |||
| # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) | |||
| # self.activate = ScaledLeakyReLU(0.2) | |||
| self.activate = FusedLeakyReLU(out_channel) | |||
| def forward(self, input, style, noise=None): | |||
| out = self.conv(input, style) | |||
| out = self.noise(out, noise=noise) | |||
| # out = out + self.bias | |||
| out = self.activate(out) | |||
| return out | |||
| class ToRGB(nn.Module): | |||
| def __init__(self, | |||
| in_channel, | |||
| style_dim, | |||
| upsample=True, | |||
| blur_kernel=[1, 3, 3, 1]): | |||
| super().__init__() | |||
| if upsample: | |||
| self.upsample = Upsample(blur_kernel) | |||
| self.conv = ModulatedConv2d( | |||
| in_channel, 3, 1, style_dim, demodulate=False) | |||
| self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) | |||
| def forward(self, input, style, skip=None): | |||
| out = self.conv(input, style) | |||
| out = out + self.bias | |||
| if skip is not None: | |||
| skip = self.upsample(skip) | |||
| out = out + skip | |||
| return out | |||
| class Generator(nn.Module): | |||
| def __init__( | |||
| self, | |||
| size, | |||
| style_dim, | |||
| n_mlp, | |||
| channel_multiplier=2, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| lr_mlp=0.01, | |||
| ): | |||
| super().__init__() | |||
| self.size = size | |||
| self.style_dim = style_dim | |||
| layers = [PixelNorm()] | |||
| for i in range(n_mlp): | |||
| layers.append( | |||
| EqualLinear( | |||
| style_dim, | |||
| style_dim, | |||
| lr_mul=lr_mlp, | |||
| activation='fused_lrelu')) | |||
| self.style = nn.Sequential(*layers) | |||
| self.channels = { | |||
| 4: 512, | |||
| 8: 512, | |||
| 16: 512, | |||
| 32: 512, | |||
| 64: 256 * channel_multiplier, | |||
| 128: 128 * channel_multiplier, | |||
| 256: 64 * channel_multiplier, | |||
| 512: 32 * channel_multiplier, | |||
| 1024: 16 * channel_multiplier, | |||
| } | |||
| self.input = ConstantInput(self.channels[4]) | |||
| self.conv1 = StyledConv( | |||
| self.channels[4], | |||
| self.channels[4], | |||
| 3, | |||
| style_dim, | |||
| blur_kernel=blur_kernel) | |||
| self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) | |||
| self.log_size = int(math.log(size, 2)) | |||
| self.num_layers = (self.log_size - 2) * 2 + 1 | |||
| self.convs = nn.ModuleList() | |||
| self.upsamples = nn.ModuleList() | |||
| self.to_rgbs = nn.ModuleList() | |||
| self.noises = nn.Module() | |||
| in_channel = self.channels[4] | |||
| for layer_idx in range(self.num_layers): | |||
| res = (layer_idx + 5) // 2 | |||
| shape = [1, 1, 2**res, 2**res] | |||
| self.noises.register_buffer(f'noise_{layer_idx}', | |||
| torch.randn(*shape)) | |||
| for i in range(3, self.log_size + 1): | |||
| out_channel = self.channels[2**i] | |||
| self.convs.append( | |||
| StyledConv( | |||
| in_channel, | |||
| out_channel, | |||
| 3, | |||
| style_dim, | |||
| upsample=True, | |||
| blur_kernel=blur_kernel, | |||
| )) | |||
| self.convs.append( | |||
| StyledConv( | |||
| out_channel, | |||
| out_channel, | |||
| 3, | |||
| style_dim, | |||
| blur_kernel=blur_kernel)) | |||
| self.to_rgbs.append(ToRGB(out_channel, style_dim)) | |||
| in_channel = out_channel | |||
| self.n_latent = self.log_size * 2 - 2 | |||
| def make_noise(self): | |||
| device = self.input.input.device | |||
| noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] | |||
| for i in range(3, self.log_size + 1): | |||
| for _ in range(2): | |||
| noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) | |||
| return noises | |||
| def mean_latent(self, n_latent): | |||
| latent_in = torch.randn( | |||
| n_latent, self.style_dim, device=self.input.input.device) | |||
| latent = self.style(latent_in).mean(0, keepdim=True) | |||
| return latent | |||
| def get_latent(self, input): | |||
| return self.style(input) | |||
| def forward( | |||
| self, | |||
| styles, | |||
| return_latents=False, | |||
| inject_index=None, | |||
| truncation=1, | |||
| truncation_latent=None, | |||
| input_is_latent=False, | |||
| noise=None, | |||
| randomize_noise=True, | |||
| ): | |||
| if not input_is_latent: | |||
| styles = [self.style(s) for s in styles] | |||
| if noise is None: | |||
| if randomize_noise: | |||
| noise = [None] * self.num_layers | |||
| else: | |||
| noise = [ | |||
| getattr(self.noises, f'noise_{i}') | |||
| for i in range(self.num_layers) | |||
| ] | |||
| if truncation < 1: | |||
| style_t = [] | |||
| for style in styles: | |||
| style_t.append(truncation_latent | |||
| + truncation * (style - truncation_latent)) | |||
| styles = style_t | |||
| if len(styles) < 2: | |||
| inject_index = self.n_latent | |||
| if styles[0].ndim < 3: | |||
| latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||
| else: | |||
| latent = styles[0] | |||
| else: | |||
| if inject_index is None: | |||
| inject_index = random.randint(1, self.n_latent - 1) | |||
| latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) | |||
| latent2 = styles[1].unsqueeze(1).repeat( | |||
| 1, self.n_latent - inject_index, 1) | |||
| latent = torch.cat([latent, latent2], 1) | |||
| out = self.input(latent) | |||
| out = self.conv1(out, latent[:, 0], noise=noise[0]) | |||
| skip = self.to_rgb1(out, latent[:, 1]) | |||
| i = 1 | |||
| for conv1, conv2, noise1, noise2, to_rgb in zip( | |||
| self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], | |||
| self.to_rgbs): | |||
| out = conv1(out, latent[:, i], noise=noise1) | |||
| out = conv2(out, latent[:, i + 1], noise=noise2) | |||
| skip = to_rgb(out, latent[:, i + 2], skip) | |||
| i += 2 | |||
| image = skip | |||
| if return_latents: | |||
| return image, latent | |||
| else: | |||
| return image, None | |||
| class ConvLayer(nn.Sequential): | |||
| def __init__( | |||
| self, | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| downsample=False, | |||
| blur_kernel=[1, 3, 3, 1], | |||
| bias=True, | |||
| activate=True, | |||
| ): | |||
| layers = [] | |||
| if downsample: | |||
| factor = 2 | |||
| p = (len(blur_kernel) - factor) + (kernel_size - 1) | |||
| pad0 = (p + 1) // 2 | |||
| pad1 = p // 2 | |||
| layers.append(Blur(blur_kernel, pad=(pad0, pad1))) | |||
| stride = 2 | |||
| self.padding = 0 | |||
| else: | |||
| stride = 1 | |||
| self.padding = kernel_size // 2 | |||
| layers.append( | |||
| EqualConv2d( | |||
| in_channel, | |||
| out_channel, | |||
| kernel_size, | |||
| padding=self.padding, | |||
| stride=stride, | |||
| bias=bias and not activate, | |||
| )) | |||
| if activate: | |||
| layers.append(FusedLeakyReLU(out_channel, bias=bias)) | |||
| super().__init__(*layers) | |||
| class ResBlock(nn.Module): | |||
| def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): | |||
| super().__init__() | |||
| self.conv1 = ConvLayer(in_channel, in_channel, 3) | |||
| self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) | |||
| self.skip = ConvLayer( | |||
| in_channel, | |||
| out_channel, | |||
| 1, | |||
| downsample=True, | |||
| activate=False, | |||
| bias=False) | |||
| def forward(self, input): | |||
| out = self.conv1(input) | |||
| out = self.conv2(out) | |||
| skip = self.skip(input) | |||
| out = (out + skip) / math.sqrt(2) | |||
| return out | |||
| class Discriminator(nn.Module): | |||
| def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): | |||
| super().__init__() | |||
| channels = { | |||
| 4: 512, | |||
| 8: 512, | |||
| 16: 512, | |||
| 32: 512, | |||
| 64: 256 * channel_multiplier, | |||
| 128: 128 * channel_multiplier, | |||
| 256: 64 * channel_multiplier, | |||
| 512: 32 * channel_multiplier, | |||
| 1024: 16 * channel_multiplier, | |||
| } | |||
| convs = [ConvLayer(3, channels[size], 1)] | |||
| log_size = int(math.log(size, 2)) | |||
| in_channel = channels[size] | |||
| for i in range(log_size, 2, -1): | |||
| out_channel = channels[2**(i - 1)] | |||
| convs.append(ResBlock(in_channel, out_channel, blur_kernel)) | |||
| in_channel = out_channel | |||
| self.convs = nn.Sequential(*convs) | |||
| self.stddev_group = 4 | |||
| self.stddev_feat = 1 | |||
| self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) | |||
| self.final_linear = nn.Sequential( | |||
| EqualLinear( | |||
| channels[4] * 4 * 4, channels[4], activation='fused_lrelu'), | |||
| EqualLinear(channels[4], 1), | |||
| ) | |||
| def forward(self, input): | |||
| out = self.convs(input) | |||
| batch, channel, height, width = out.shape | |||
| group = min(batch, self.stddev_group) | |||
| stddev = out.view(group, -1, self.stddev_feat, | |||
| channel // self.stddev_feat, height, width) | |||
| stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) | |||
| stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) | |||
| stddev = stddev.repeat(group, 1, height, width) | |||
| out = torch.cat([out, stddev], 1) | |||
| out = self.final_conv(out) | |||
| out = out.view(batch, -1) | |||
| out = self.final_linear(out) | |||
| return out | |||
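Reviewer note: a minimal sketch of driving the generator directly, roughly what `FaceImageGenerationPipeline.forward` below does with the pretrained `g_ema` weights; here the weights are randomly initialized, so the output is noise rather than a face, but it exercises the full forward path on CPU.

```python
import torch

# Assumed final location of the model file added in this diff.
from modelscope.models.cv.face_generation import stylegan2

g = stylegan2.Generator(size=256, style_dim=512, n_mlp=8, channel_multiplier=2)
g.eval()

with torch.no_grad():
    mean_latent = g.mean_latent(4096)  # average W vector used for truncation
    z = torch.randn(1, 512)            # a single latent code
    img, _ = g([z], truncation=0.7, truncation_latent=mean_latent)

print(img.shape)  # torch.Size([1, 3, 256, 256]); the pipeline rescales values to [0, 255]
```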
| @@ -6,6 +6,7 @@ try: | |||
| from .action_recognition_pipeline import ActionRecognitionPipeline | |||
| from .animal_recog_pipeline import AnimalRecogPipeline | |||
| from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline | |||
| from .face_image_generation_pipeline import FaceImageGenerationPipeline | |||
| except ModuleNotFoundError as e: | |||
| if str(e) == "No module named 'torch'": | |||
| pass | |||
| @@ -0,0 +1,79 @@ | |||
| import os | |||
| from typing import Any, Dict | |||
| import cv2 | |||
| import numpy as np | |||
| import PIL | |||
| import torch | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.cv.face_generation import stylegan2 | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines.base import Input | |||
| from modelscope.preprocessors import load_image | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from ..base import Pipeline | |||
| from ..builder import PIPELINES | |||
| logger = get_logger() | |||
| @PIPELINES.register_module( | |||
| Tasks.image_generation, module_name=Pipelines.face_image_generation) | |||
| class FaceImageGenerationPipeline(Pipeline): | |||
| def __init__(self, model: str): | |||
| """ | |||
| use `model` to create a face image generation pipeline for prediction | |||
| Args: | |||
| model: model id on modelscope hub. | |||
| """ | |||
| super().__init__(model=model) | |||
| self.size = 1024 | |||
| self.latent = 512 | |||
| self.n_mlp = 8 | |||
| self.channel_multiplier = 2 | |||
| self.truncation = 0.7 | |||
| self.truncation_mean = 4096 | |||
| self.generator = stylegan2.Generator( | |||
| self.size, | |||
| self.latent, | |||
| self.n_mlp, | |||
| channel_multiplier=self.channel_multiplier) | |||
| self.model_file = f'{model}/{ModelFile.TORCH_MODEL_FILE}' | |||
| self.generator.load_state_dict(torch.load(self.model_file)['g_ema']) | |||
| logger.info('load model done') | |||
| self.mean_latent = None | |||
| if self.truncation < 1: | |||
| with torch.no_grad(): | |||
| self.mean_latent = self.generator.mean_latent( | |||
| self.truncation_mean) | |||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||
| return input | |||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||
| assert isinstance(input, int) | |||
| torch.manual_seed(input) | |||
| torch.cuda.manual_seed(input) | |||
| torch.cuda.manual_seed_all(input) | |||
| self.generator.eval() | |||
| with torch.no_grad(): | |||
| sample_z = torch.randn(1, self.latent) | |||
| sample, _ = self.generator([sample_z], | |||
| truncation=self.truncation, | |||
| truncation_latent=self.mean_latent) | |||
| sample = sample * 0.5 + 0.5 | |||
| sample = sample.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR | |||
| sample = np.clip(sample.float().cpu().numpy(), 0, 1) * 255 | |||
| return {OutputKeys.OUTPUT_IMG: sample.astype(np.uint8)} | |||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||
| return inputs | |||
| @@ -0,0 +1,37 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import os.path as osp | |||
| import unittest | |||
| import cv2 | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.test_utils import test_level | |||
| class FaceGenerationTest(unittest.TestCase): | |||
| def setUp(self) -> None: | |||
| self.model_id = 'damo/cv_gan_face-image-generation' | |||
| def pipeline_inference(self, pipeline: Pipeline, seed: int): | |||
| result = pipeline(seed) | |||
| if result is not None: | |||
| cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) | |||
| print(f'Output written to {osp.abspath("result.png")}') | |||
| @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') | |||
| def test_run_modelhub(self): | |||
| seed = 10 | |||
| face_generation = pipeline( | |||
| Tasks.image_generation, | |||
| model=self.model_id, | |||
| ) | |||
| self.pipeline_inference(face_generation, seed) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||