
[to #42322933] Face generation

Added a StyleGAN2-based face image generation algorithm; a usage sketch follows the changed-file list below.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9426478
master
baiguan.yt yingda.chen 3 years ago
commit 51bd47a72c
9 changed files with 1390 additions and 0 deletions
  1. modelscope/metainfo.py (+1 -0)
  2. modelscope/models/cv/face_generation/op/__init__.py (+2 -0)
  3. modelscope/models/cv/face_generation/op/conv2d_gradfix.py (+229 -0)
  4. modelscope/models/cv/face_generation/op/fused_act.py (+113 -0)
  5. modelscope/models/cv/face_generation/op/upfirdn2d.py (+197 -0)
  6. modelscope/models/cv/face_generation/stylegan2.py (+731 -0)
  7. modelscope/pipelines/cv/__init__.py (+1 -0)
  8. modelscope/pipelines/cv/face_image_generation_pipeline.py (+79 -0)
  9. tests/pipelines/test_face_image_generation.py (+37 -0)
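
For orientation, a minimal usage sketch of the new pipeline. It mirrors the test case added in this commit; the model id and output key are taken from tests/pipelines/test_face_image_generation.py:

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# The pipeline takes an integer random seed and returns a BGR uint8 image.
face_generation = pipeline(
    Tasks.image_generation, model='damo/cv_gan_face-image-generation')
result = face_generation(10)
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])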

modelscope/metainfo.py (+1 -0)

@@ -49,6 +49,7 @@ class Pipelines(object):
action_recognition = 'TAdaConv_action-recognition'
animal_recognation = 'resnet101-animal_recog'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
face_image_generation = 'gan-face-image-generation'
style_transfer = 'AAMS-style-transfer'

# nlp tasks


modelscope/models/cv/face_generation/op/__init__.py (+2 -0)

@@ -0,0 +1,2 @@
from .fused_act import FusedLeakyReLU, fused_leaky_relu
from .upfirdn2d import upfirdn2d

modelscope/models/cv/face_generation/op/conv2d_gradfix.py (+229 -0)

@@ -0,0 +1,229 @@
import contextlib
import warnings

import torch
from torch import autograd
from torch.nn import functional as F

enabled = True
weight_gradients_disabled = False


@contextlib.contextmanager
def no_weight_gradients():
global weight_gradients_disabled

old = weight_gradients_disabled
weight_gradients_disabled = True
yield
weight_gradients_disabled = old


def conv2d(input,
weight,
bias=None,
stride=1,
padding=0,
dilation=1,
groups=1):
if could_use_op(input):
return conv2d_gradfix(
transpose=False,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=0,
dilation=dilation,
groups=groups,
).apply(input, weight, bias)

return F.conv2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)


def conv_transpose2d(
input,
weight,
bias=None,
stride=1,
padding=0,
output_padding=0,
groups=1,
dilation=1,
):
if could_use_op(input):
return conv2d_gradfix(
transpose=True,
weight_shape=weight.shape,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
dilation=dilation,
).apply(input, weight, bias)

return F.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
stride=stride,
padding=padding,
output_padding=output_padding,
dilation=dilation,
groups=groups,
)


def could_use_op(input):
if (not enabled) or (not torch.backends.cudnn.enabled):
return False

if input.device.type != 'cuda':
return False

if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.']):
return True

warnings.warn(
f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().'
)

return False


def ensure_tuple(xs, ndim):
xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs, ) * ndim

return xs


conv2d_gradfix_cache = dict()


def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
dilation, groups):
ndim = 2
weight_shape = tuple(weight_shape)
stride = ensure_tuple(stride, ndim)
padding = ensure_tuple(padding, ndim)
output_padding = ensure_tuple(output_padding, ndim)
dilation = ensure_tuple(dilation, ndim)

key = (transpose, weight_shape, stride, padding, output_padding, dilation,
groups)
if key in conv2d_gradfix_cache:
return conv2d_gradfix_cache[key]

common_kwargs = dict(
stride=stride, padding=padding, dilation=dilation, groups=groups)

def calc_output_padding(input_shape, output_shape):
    if transpose:
        return [0, 0]

    # Output padding that lets the transposed gradient convolution
    # reproduce input_shape exactly.
    return [
        input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i]
        - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
        for i in range(ndim)
    ]

class Conv2d(autograd.Function):

@staticmethod
def forward(ctx, input, weight, bias):
if not transpose:
out = F.conv2d(
input=input, weight=weight, bias=bias, **common_kwargs)

else:
out = F.conv_transpose2d(
input=input,
weight=weight,
bias=bias,
output_padding=output_padding,
**common_kwargs,
)

ctx.save_for_backward(input, weight)

return out

@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
grad_input, grad_weight, grad_bias = None, None, None

if ctx.needs_input_grad[0]:
p = calc_output_padding(
input_shape=input.shape, output_shape=grad_output.shape)
grad_input = conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs,
).apply(grad_output, weight, None)

if ctx.needs_input_grad[1] and not weight_gradients_disabled:
grad_weight = Conv2dGradWeight.apply(grad_output, input)

if ctx.needs_input_grad[2]:
grad_bias = grad_output.sum((0, 2, 3))

return grad_input, grad_weight, grad_bias

class Conv2dGradWeight(autograd.Function):

@staticmethod
def forward(ctx, grad_output, input):
op = torch._C._jit_get_operation(
'aten::cudnn_convolution_backward_weight' if not transpose else
'aten::cudnn_convolution_transpose_backward_weight')
flags = [
torch.backends.cudnn.benchmark,
torch.backends.cudnn.deterministic,
torch.backends.cudnn.allow_tf32,
]
grad_weight = op(
weight_shape,
grad_output,
input,
padding,
stride,
dilation,
groups,
*flags,
)
ctx.save_for_backward(grad_output, input)

return grad_weight

@staticmethod
def backward(ctx, grad_grad_weight):
grad_output, input = ctx.saved_tensors
grad_grad_output, grad_grad_input = None, None

if ctx.needs_input_grad[0]:
grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

if ctx.needs_input_grad[1]:
p = calc_output_padding(
input_shape=input.shape, output_shape=grad_output.shape)
grad_grad_input = conv2d_gradfix(
transpose=(not transpose),
weight_shape=weight_shape,
output_padding=p,
**common_kwargs,
).apply(grad_output, grad_grad_weight, None)

return grad_grad_output, grad_grad_input

conv2d_gradfix_cache[key] = Conv2d

return Conv2d
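
Usage note: the point of this module is no_weight_gradients(), which lets training code backpropagate through convolutions while skipping grad_weight, the trick StyleGAN2 training uses for its path-length regularization pass. A minimal sketch with toy stand-ins (not part of this commit; could_use_op() above only engages the custom Function for CUDA tensors on torch 1.7/1.8, otherwise this degrades to plain F.conv2d and the switch is a no-op):

import torch

from modelscope.models.cv.face_generation.op import conv2d_gradfix

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(1, 3, 8, 8, device=device, requires_grad=True)
w = torch.randn(4, 3, 3, 3, device=device, requires_grad=True)

with conv2d_gradfix.no_weight_gradients():
    y = conv2d_gradfix.conv2d(x, w, padding=1)
    # On the custom path, backward skips the weight gradient (w.grad stays
    # None) while the input gradient is still computed.
    y.sum().backward()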

modelscope/models/cv/face_generation/op/fused_act.py (+113 -0)

@@ -0,0 +1,113 @@
import os

import torch
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F

# No compiled CUDA extension is loaded in this port, so def_lib stays False
# and only the pure-PyTorch fallback in fused_leaky_relu() below runs; the
# autograd Functions that reference the undefined `fused` module are unused
# on this path.
def_lib = False


class FusedLeakyReLUFunctionBackward(Function):

@staticmethod
def forward(ctx, grad_output, out, bias, negative_slope, scale):
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale

empty = grad_output.new_empty(0)

grad_input = fused.fused_bias_act(grad_output.contiguous(), empty, out,
3, 1, negative_slope, scale)

dim = [0]

if grad_input.ndim > 2:
dim += list(range(2, grad_input.ndim))

if bias:
grad_bias = grad_input.sum(dim).detach()

else:
grad_bias = empty

return grad_input, grad_bias

@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
out, = ctx.saved_tensors
gradgrad_out = fused.fused_bias_act(
gradgrad_input.contiguous(),
gradgrad_bias,
out,
3,
1,
ctx.negative_slope,
ctx.scale,
)

return gradgrad_out, None, None, None, None


class FusedLeakyReLUFunction(Function):

@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)

ctx.bias = bias is not None

if bias is None:
bias = empty

out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope,
scale)
ctx.save_for_backward(out)
ctx.negative_slope = negative_slope
ctx.scale = scale

return out

@staticmethod
def backward(ctx, grad_output):
out, = ctx.saved_tensors

grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale)

if not ctx.bias:
grad_bias = None

return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):

def __init__(self, channel, bias=True, negative_slope=0.2, scale=2**0.5):
super().__init__()

if bias:
self.bias = nn.Parameter(torch.zeros(channel))

else:
self.bias = None

self.negative_slope = negative_slope
self.scale = scale

def forward(self, input):
return fused_leaky_relu(input, self.bias, self.negative_slope,
self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
    if not def_lib:
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim),
                negative_slope=negative_slope) * scale)

        else:
            return F.leaky_relu(input, negative_slope=negative_slope) * scale
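
A quick sanity check of the fallback (a sketch with arbitrary shapes; since def_lib is False, fused_leaky_relu should match plain PyTorch exactly):

import torch
from torch.nn import functional as F

from modelscope.models.cv.face_generation.op import fused_leaky_relu

x = torch.randn(2, 8, 4, 4)
b = torch.randn(8)
out = fused_leaky_relu(x, b)
# The fallback is leaky_relu(x + b) scaled by sqrt(2) to preserve variance.
ref = F.leaky_relu(x + b.view(1, 8, 1, 1), negative_slope=0.2) * (2**0.5)
assert torch.allclose(out, ref)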

modelscope/models/cv/face_generation/op/upfirdn2d.py (+197 -0)

@@ -0,0 +1,197 @@
import os
from collections import abc

import torch
from torch.autograd import Function
from torch.nn import functional as F

# No compiled CUDA extension is loaded in this port, so def_lib stays False
# and upfirdn2d() always dispatches to upfirdn2d_native(); the autograd
# Functions that reference the undefined `upfirdn2d_op` are unused on this
# path.
def_lib = False


class UpFirDn2dBackward(Function):

@staticmethod
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
in_size, out_size):

up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad

grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)

grad_input = upfirdn2d_op.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2],
in_size[3])

ctx.save_for_backward(kernel)

pad_x0, pad_x1, pad_y0, pad_y1 = pad

ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size

return grad_input

@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors

gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2],
ctx.in_size[3], 1)

gradgrad_out = upfirdn2d_op.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1],
ctx.out_size[0], ctx.out_size[1])

return gradgrad_out, None, None, None, None, None, None, None, None


class UpFirDn2d(Function):

@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad

kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape

input = input.reshape(-1, in_h, in_w, 1)

ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))

out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
ctx.out_size = (out_h, out_w)

ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)

g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1

ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)

out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y,
pad_x0, pad_x1, pad_y0, pad_y1)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)

return out

@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors

grad_input = None

if ctx.needs_input_grad[0]:
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)

return grad_input, None, None, None, None


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    # def_lib is always False in this module, so use the native path.
    return upfirdn2d_native(input, kernel, *up, *down, *pad)


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)

_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape

out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)

out = F.pad(
out,
[0, 0,
max(pad_x0, 0),
max(pad_x1, 0),
max(pad_y0, 0),
max(pad_y1, 0)])
out = out[:,
max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0)]

out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]

out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

return out.view(-1, channel, out_h, out_w)
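
As a usage sketch (arbitrary sizes; this mirrors how the Upsample module in stylegan2.py drives upfirdn2d for 2x upsampling with the [1, 3, 3, 1] binomial blur):

import torch

from modelscope.models.cv.face_generation.op import upfirdn2d

# Build the separable blur kernel the way make_kernel() in stylegan2.py
# does, scaled by factor**2 so upsampling preserves signal magnitude.
k = torch.tensor([1., 3., 3., 1.])
k = k[None, :] * k[:, None]
k = k / k.sum() * 4

x = torch.randn(1, 3, 16, 16)
p = k.shape[0] - 2  # kernel size minus the upsampling factor
y = upfirdn2d(x, k, up=2, down=1, pad=((p + 1) // 2 + 1, p // 2))
print(y.shape)  # torch.Size([1, 3, 32, 32])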

modelscope/models/cv/face_generation/stylegan2.py (+731 -0)

@@ -0,0 +1,731 @@
import functools
import math
import operator
import random

import torch
from torch import nn
from torch.autograd import Function
from torch.nn import functional as F

from .op import FusedLeakyReLU, conv2d_gradfix, fused_leaky_relu, upfirdn2d


class PixelNorm(nn.Module):

def __init__(self):
super().__init__()

def forward(self, input):
return input * torch.rsqrt(
torch.mean(input**2, dim=1, keepdim=True) + 1e-8)


def make_kernel(k):
k = torch.tensor(k, dtype=torch.float32)

if k.ndim == 1:
k = k[None, :] * k[:, None]

k /= k.sum()

return k


class Upsample(nn.Module):

def __init__(self, kernel, factor=2):
super().__init__()

self.factor = factor
kernel = make_kernel(kernel) * (factor**2)
self.register_buffer('kernel', kernel)

p = kernel.shape[0] - factor

pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2

self.pad = (pad0, pad1)

def forward(self, input):
out = upfirdn2d(
input, self.kernel, up=self.factor, down=1, pad=self.pad)

return out


class Downsample(nn.Module):

def __init__(self, kernel, factor=2):
super().__init__()

self.factor = factor
kernel = make_kernel(kernel)
self.register_buffer('kernel', kernel)

p = kernel.shape[0] - factor

pad0 = (p + 1) // 2
pad1 = p // 2

self.pad = (pad0, pad1)

def forward(self, input):
out = upfirdn2d(
input, self.kernel, up=1, down=self.factor, pad=self.pad)

return out


class Blur(nn.Module):

def __init__(self, kernel, pad, upsample_factor=1):
super().__init__()

kernel = make_kernel(kernel)

if upsample_factor > 1:
kernel = kernel * (upsample_factor**2)

self.register_buffer('kernel', kernel)

self.pad = pad

def forward(self, input):
out = upfirdn2d(input, self.kernel, pad=self.pad)

return out


class EqualConv2d(nn.Module):

def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
bias=True):
super().__init__()

self.weight = nn.Parameter(
torch.randn(out_channel, in_channel, kernel_size, kernel_size))
self.scale = 1 / math.sqrt(in_channel * kernel_size**2)

self.stride = stride
self.padding = padding

if bias:
self.bias = nn.Parameter(torch.zeros(out_channel))

else:
self.bias = None

def forward(self, input):
out = conv2d_gradfix.conv2d(
input,
self.weight * self.scale,
bias=self.bias,
stride=self.stride,
padding=self.padding,
)

return out

def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
)


class EqualLinear(nn.Module):

def __init__(self,
in_dim,
out_dim,
bias=True,
bias_init=0,
lr_mul=1,
activation=None):
super().__init__()

self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

if bias:
self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

else:
self.bias = None

self.activation = activation

self.scale = (1 / math.sqrt(in_dim)) * lr_mul
self.lr_mul = lr_mul

def forward(self, input):
if self.activation:
out = F.linear(input, self.weight * self.scale)
out = fused_leaky_relu(out, self.bias * self.lr_mul)

else:
out = F.linear(
input, self.weight * self.scale, bias=self.bias * self.lr_mul)

return out

def __repr__(self):
return (
f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
)


class ModulatedConv2d(nn.Module):

def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim,
demodulate=True,
upsample=False,
downsample=False,
blur_kernel=[1, 3, 3, 1],
fused=True,
):
super().__init__()

self.eps = 1e-8
self.kernel_size = kernel_size
self.in_channel = in_channel
self.out_channel = out_channel
self.upsample = upsample
self.downsample = downsample

if upsample:
factor = 2
p = (len(blur_kernel) - factor) - (kernel_size - 1)
pad0 = (p + 1) // 2 + factor - 1
pad1 = p // 2 + 1

self.blur = Blur(
blur_kernel, pad=(pad0, pad1), upsample_factor=factor)

if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2

self.blur = Blur(blur_kernel, pad=(pad0, pad1))

fan_in = in_channel * kernel_size**2
self.scale = 1 / math.sqrt(fan_in)
self.padding = kernel_size // 2

self.weight = nn.Parameter(
torch.randn(1, out_channel, in_channel, kernel_size, kernel_size))

self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)

self.demodulate = demodulate
self.fused = fused

def __repr__(self):
return (
f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
f'upsample={self.upsample}, downsample={self.downsample})')

def forward(self, input, style):
batch, in_channel, height, width = input.shape

if not self.fused:
weight = self.scale * self.weight.squeeze(0)
style = self.modulation(style)

if self.demodulate:
w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1,
1)
dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

input = input * style.reshape(batch, in_channel, 1, 1)

if self.upsample:
weight = weight.transpose(0, 1)
out = conv2d_gradfix.conv_transpose2d(
input, weight, padding=0, stride=2)
out = self.blur(out)

elif self.downsample:
input = self.blur(input)
out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)

else:
out = conv2d_gradfix.conv2d(
input, weight, padding=self.padding)

if self.demodulate:
out = out * dcoefs.view(batch, -1, 1, 1)

return out

style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
weight = self.scale * self.weight * style

if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

weight = weight.view(batch * self.out_channel, in_channel,
self.kernel_size, self.kernel_size)

if self.upsample:
input = input.view(1, batch * in_channel, height, width)
weight = weight.view(batch, self.out_channel, in_channel,
self.kernel_size, self.kernel_size)
weight = weight.transpose(1, 2).reshape(batch * in_channel,
self.out_channel,
self.kernel_size,
self.kernel_size)
out = conv2d_gradfix.conv_transpose2d(
input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)
out = self.blur(out)

elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
input = input.view(1, batch * in_channel, height, width)
out = conv2d_gradfix.conv2d(
input, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)

else:
input = input.view(1, batch * in_channel, height, width)
out = conv2d_gradfix.conv2d(
input, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape
out = out.view(batch, self.out_channel, height, width)

return out


class NoiseInjection(nn.Module):

def __init__(self):
super().__init__()

self.weight = nn.Parameter(torch.zeros(1))

def forward(self, image, noise=None):
if noise is None:
batch, _, height, width = image.shape
noise = image.new_empty(batch, 1, height, width).normal_()

return image + self.weight * noise


class ConstantInput(nn.Module):

def __init__(self, channel, size=4):
super().__init__()

self.input = nn.Parameter(torch.randn(1, channel, size, size))

def forward(self, input):
batch = input.shape[0]
out = self.input.repeat(batch, 1, 1, 1)

return out


class StyledConv(nn.Module):

def __init__(
self,
in_channel,
out_channel,
kernel_size,
style_dim,
upsample=False,
blur_kernel=[1, 3, 3, 1],
demodulate=True,
):
super().__init__()

self.conv = ModulatedConv2d(
in_channel,
out_channel,
kernel_size,
style_dim,
upsample=upsample,
blur_kernel=blur_kernel,
demodulate=demodulate,
)

self.noise = NoiseInjection()
# self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
# self.activate = ScaledLeakyReLU(0.2)
self.activate = FusedLeakyReLU(out_channel)

def forward(self, input, style, noise=None):
out = self.conv(input, style)
out = self.noise(out, noise=noise)
# out = out + self.bias
out = self.activate(out)

return out


class ToRGB(nn.Module):

def __init__(self,
in_channel,
style_dim,
upsample=True,
blur_kernel=[1, 3, 3, 1]):
super().__init__()

if upsample:
self.upsample = Upsample(blur_kernel)

self.conv = ModulatedConv2d(
in_channel, 3, 1, style_dim, demodulate=False)
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

def forward(self, input, style, skip=None):
out = self.conv(input, style)
out = out + self.bias

if skip is not None:
skip = self.upsample(skip)

out = out + skip

return out


class Generator(nn.Module):

def __init__(
self,
size,
style_dim,
n_mlp,
channel_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
):
super().__init__()

self.size = size

self.style_dim = style_dim

layers = [PixelNorm()]

for i in range(n_mlp):
layers.append(
EqualLinear(
style_dim,
style_dim,
lr_mul=lr_mlp,
activation='fused_lrelu'))

self.style = nn.Sequential(*layers)

self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}

self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv(
self.channels[4],
self.channels[4],
3,
style_dim,
blur_kernel=blur_kernel)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1

self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()

in_channel = self.channels[4]

for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2
shape = [1, 1, 2**res, 2**res]
self.noises.register_buffer(f'noise_{layer_idx}',
torch.randn(*shape))

for i in range(3, self.log_size + 1):
out_channel = self.channels[2**i]

self.convs.append(
StyledConv(
in_channel,
out_channel,
3,
style_dim,
upsample=True,
blur_kernel=blur_kernel,
))

self.convs.append(
StyledConv(
out_channel,
out_channel,
3,
style_dim,
blur_kernel=blur_kernel))

self.to_rgbs.append(ToRGB(out_channel, style_dim))

in_channel = out_channel

self.n_latent = self.log_size * 2 - 2

def make_noise(self):
device = self.input.input.device

noises = [torch.randn(1, 1, 2**2, 2**2, device=device)]

for i in range(3, self.log_size + 1):
for _ in range(2):
noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))

return noises

def mean_latent(self, n_latent):
latent_in = torch.randn(
n_latent, self.style_dim, device=self.input.input.device)
latent = self.style(latent_in).mean(0, keepdim=True)

return latent

def get_latent(self, input):
return self.style(input)

def forward(
self,
styles,
return_latents=False,
inject_index=None,
truncation=1,
truncation_latent=None,
input_is_latent=False,
noise=None,
randomize_noise=True,
):
if not input_is_latent:
styles = [self.style(s) for s in styles]

if noise is None:
if randomize_noise:
noise = [None] * self.num_layers
else:
noise = [
getattr(self.noises, f'noise_{i}')
for i in range(self.num_layers)
]

if truncation < 1:
style_t = []

for style in styles:
style_t.append(truncation_latent
+ truncation * (style - truncation_latent))

styles = style_t

if len(styles) < 2:
inject_index = self.n_latent

if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

else:
latent = styles[0]

else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)

latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(
1, self.n_latent - inject_index, 1)

latent = torch.cat([latent, latent2], 1)

out = self.input(latent)
out = self.conv1(out, latent[:, 0], noise=noise[0])

skip = self.to_rgb1(out, latent[:, 1])

i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2],
self.to_rgbs):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)

i += 2

image = skip

if return_latents:
return image, latent

else:
return image, None


class ConvLayer(nn.Sequential):

def __init__(
self,
in_channel,
out_channel,
kernel_size,
downsample=False,
blur_kernel=[1, 3, 3, 1],
bias=True,
activate=True,
):
layers = []

if downsample:
factor = 2
p = (len(blur_kernel) - factor) + (kernel_size - 1)
pad0 = (p + 1) // 2
pad1 = p // 2

layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

stride = 2
self.padding = 0

else:
stride = 1
self.padding = kernel_size // 2

layers.append(
EqualConv2d(
in_channel,
out_channel,
kernel_size,
padding=self.padding,
stride=stride,
bias=bias and not activate,
))

if activate:
layers.append(FusedLeakyReLU(out_channel, bias=bias))

super().__init__(*layers)


class ResBlock(nn.Module):

def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
super().__init__()

self.conv1 = ConvLayer(in_channel, in_channel, 3)
self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

self.skip = ConvLayer(
in_channel,
out_channel,
1,
downsample=True,
activate=False,
bias=False)

def forward(self, input):
out = self.conv1(input)
out = self.conv2(out)

skip = self.skip(input)
out = (out + skip) / math.sqrt(2)

return out


class Discriminator(nn.Module):

def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
super().__init__()

channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}

convs = [ConvLayer(3, channels[size], 1)]

log_size = int(math.log(size, 2))

in_channel = channels[size]

for i in range(log_size, 2, -1):
out_channel = channels[2**(i - 1)]

convs.append(ResBlock(in_channel, out_channel, blur_kernel))

in_channel = out_channel

self.convs = nn.Sequential(*convs)

self.stddev_group = 4
self.stddev_feat = 1

self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
self.final_linear = nn.Sequential(
EqualLinear(
channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
EqualLinear(channels[4], 1),
)

def forward(self, input):
out = self.convs(input)

batch, channel, height, width = out.shape
group = min(batch, self.stddev_group)
stddev = out.view(group, -1, self.stddev_feat,
channel // self.stddev_feat, height, width)
stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
stddev = stddev.repeat(group, 1, height, width)
out = torch.cat([out, stddev], 1)

out = self.final_conv(out)

out = out.view(batch, -1)
out = self.final_linear(out)

return out
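
To make the generator's interface concrete, a minimal sampling sketch with randomly initialized weights (the pipeline below instead loads the trained 'g_ema' checkpoint; 1024/512/8 match its configuration):

import torch

g = Generator(1024, 512, 8)  # size, style_dim, n_mlp
g.eval()
with torch.no_grad():
    mean_latent = g.mean_latent(4096)  # truncation target
    z = torch.randn(1, 512)
    # truncation < 1 pulls each style toward the mean latent, trading
    # sample diversity for fidelity.
    img, _ = g([z], truncation=0.7, truncation_latent=mean_latent)
print(img.shape)  # torch.Size([1, 3, 1024, 1024])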

modelscope/pipelines/cv/__init__.py (+1 -0)

@@ -6,6 +6,7 @@ try:
from .action_recognition_pipeline import ActionRecognitionPipeline
from .animal_recog_pipeline import AnimalRecogPipeline
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
from .face_image_generation_pipeline import FaceImageGenerationPipeline
except ModuleNotFoundError as e:
if str(e) == "No module named 'torch'":
pass


modelscope/pipelines/cv/face_image_generation_pipeline.py (+79 -0)

@@ -0,0 +1,79 @@
import os
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.models.cv.face_generation import stylegan2
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES

logger = get_logger()


@PIPELINES.register_module(
Tasks.image_generation, module_name=Pipelines.face_image_generation)
class FaceImageGenerationPipeline(Pipeline):

def __init__(self, model: str):
"""
use `model` and `preprocessor` to create a kws pipeline for prediction
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
self.size = 1024
self.latent = 512
self.n_mlp = 8
self.channel_multiplier = 2
self.truncation = 0.7
self.truncation_mean = 4096
self.generator = stylegan2.Generator(
self.size,
self.latent,
self.n_mlp,
channel_multiplier=self.channel_multiplier)

self.model_file = f'{model}/{ModelFile.TORCH_MODEL_FILE}'

self.generator.load_state_dict(torch.load(self.model_file)['g_ema'])
logger.info('load model done')

self.mean_latent = None
if self.truncation < 1:
with torch.no_grad():
self.mean_latent = self.generator.mean_latent(
self.truncation_mean)

def preprocess(self, input: Input) -> Dict[str, Any]:
return input

def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
assert isinstance(input, int)
torch.manual_seed(input)
torch.cuda.manual_seed(input)
torch.cuda.manual_seed_all(input)
self.generator.eval()
with torch.no_grad():
sample_z = torch.randn(1, self.latent)

sample, _ = self.generator([sample_z],
truncation=self.truncation,
truncation_latent=self.mean_latent)

sample = sample * 0.5 + 0.5
sample = sample.squeeze(0).permute(1, 2, 0).flip(2) # RGB->BGR
sample = np.clip(sample.float().cpu().numpy(), 0, 1) * 255

return {OutputKeys.OUTPUT_IMG: sample.astype(np.uint8)}

def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
return inputs

tests/pipelines/test_face_image_generation.py (+37 -0)

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import os.path as osp
import unittest

import cv2

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class FaceGenerationTest(unittest.TestCase):

def setUp(self) -> None:
self.model_id = 'damo/cv_gan_face-image-generation'

def pipeline_inference(self, pipeline: Pipeline, seed: int):
result = pipeline(seed)
if result is not None:
cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
print(f'Output written to {osp.abspath("result.png")}')

@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_modelhub(self):
seed = 10
face_generation = pipeline(
Tasks.image_generation,
model=self.model_id,
)
self.pipeline_inference(face_generation, seed)


if __name__ == '__main__':
unittest.main()
