@contextlib.contextmanager
def no_weight_gradients():
    """Temporarily disable weight-gradient computation in ``conv2d_gradfix``.

    Sets the module-level ``weight_gradients_disabled`` flag to ``True`` for
    the duration of the ``with`` block and restores the previous value on
    exit.  The flag is read by the custom convolution's ``backward``.
    """
    global weight_gradients_disabled

    old = weight_gradients_disabled
    weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even when the body raises; otherwise one failed step would
        # leave weight gradients switched off for the rest of the process.
        weight_gradients_disabled = old
def conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding,
                   dilation, groups):
    """Return a cached ``autograd.Function`` for one convolution configuration.

    The returned class runs ``F.conv2d`` (or ``F.conv_transpose2d`` when
    ``transpose`` is true) in ``forward`` and defines ``backward`` by hand so
    that the weight gradient can be skipped while the module-level
    ``weight_gradients_disabled`` flag is set.

    Args:
        transpose: build a transposed convolution when true.
        weight_shape: shape of the weight tensor (converted to a tuple).
        stride, padding, output_padding, dilation: an int or a 2-sequence;
            normalised with ``ensure_tuple``.
        groups: number of convolution groups.

    Returns:
        The ``Conv2d`` function class; one class is created per distinct
        argument combination and stored in ``conv2d_gradfix_cache``.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = ensure_tuple(stride, ndim)
    padding = ensure_tuple(padding, ndim)
    output_padding = ensure_tuple(output_padding, ndim)
    dilation = ensure_tuple(dilation, ndim)

    key = (transpose, weight_shape, stride, padding, output_padding, dilation,
           groups)
    if key in conv2d_gradfix_cache:
        return conv2d_gradfix_cache[key]

    common_kwargs = dict(
        stride=stride, padding=padding, dilation=dilation, groups=groups)

    def calc_output_padding(input_shape, output_shape):
        """Output padding for the opposite-direction conv used in backward."""
        if transpose:
            return [0, 0]

        # The whole expression must live inside the comprehension: ``i`` is
        # only defined there (hoisting part of it out raised NameError).
        return [
            input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(autograd.Function):
        """Forward/backward pair for this convolution configuration."""

        @staticmethod
        def forward(ctx, input, weight, bias):
            if not transpose:
                out = F.conv2d(
                    input=input, weight=weight, bias=bias, **common_kwargs)

            else:
                out = F.conv_transpose2d(
                    input=input,
                    weight=weight,
                    bias=bias,
                    output_padding=output_padding,
                    **common_kwargs,
                )

            ctx.save_for_backward(input, weight)

            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            grad_input, grad_weight, grad_bias = None, None, None

            if ctx.needs_input_grad[0]:
                # Input gradient = the opposite-direction convolution of
                # grad_output with the same weight.
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, weight, None)

            # Skipped entirely inside ``no_weight_gradients()``.
            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input)

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum((0, 2, 3))

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(autograd.Function):
        """Weight gradient, obtained from the cuDNN backward-weight operator."""

        @staticmethod
        def forward(ctx, grad_output, input):
            op = torch._C._jit_get_operation(
                'aten::cudnn_convolution_backward_weight' if not transpose else
                'aten::cudnn_convolution_transpose_backward_weight')
            flags = [
                torch.backends.cudnn.benchmark,
                torch.backends.cudnn.deterministic,
                torch.backends.cudnn.allow_tf32,
            ]
            grad_weight = op(
                weight_shape,
                grad_output,
                input,
                padding,
                stride,
                dilation,
                groups,
                *flags,
            )
            ctx.save_for_backward(grad_output, input)

            return grad_weight

        @staticmethod
        def backward(ctx, grad_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad_grad_output, grad_grad_input = None, None

            if ctx.needs_input_grad[0]:
                grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)

            if ctx.needs_input_grad[1]:
                p = calc_output_padding(
                    input_shape=input.shape, output_shape=grad_output.shape)
                grad_grad_input = conv2d_gradfix(
                    transpose=(not transpose),
                    weight_shape=weight_shape,
                    output_padding=p,
                    **common_kwargs,
                ).apply(grad_output, grad_grad_weight, None)

            return grad_grad_output, grad_grad_input

    conv2d_gradfix_cache[key] = Conv2d

    return Conv2d
def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2**0.5):
    """Apply ``leaky_relu(input + bias) * scale`` using plain PyTorch ops.

    Args:
        input: tensor to activate.
        bias: optional tensor added before the activation; it is viewed as
            ``(1, bias.shape[0], 1, ...)`` so it broadcasts along dimension 1
            of ``input``.
        negative_slope: slope of the leaky ReLU for negative values.
        scale: factor applied to the activated result.

    Returns:
        The activated, scaled tensor.

    Raises:
        NotImplementedError: if the module-level ``def_lib`` flag is true;
            only the native path is implemented in this function.
    """
    if def_lib:
        # Previously this case fell through and returned None silently.
        raise NotImplementedError(
            'fused_leaky_relu: only the native implementation is available '
            '(def_lib must be False)')

    if bias is not None:
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        input = input + bias.view(1, bias.shape[0], *rest_dim)

    # Honour the caller's slope instead of a hard-coded 0.2.
    return F.leaky_relu(input, negative_slope=negative_slope) * scale
pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], + ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_op.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], + ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_op.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, + pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = 
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    """Upsample, FIR-filter and downsample a 4-D tensor.

    Args:
        input: tensor unpacked as ``(batch, channel, height, width)``.
        kernel: 2-D filter tensor ``(kernel_h, kernel_w)``.
        up: upsampling factor; a scalar is used for both axes.
        down: downsampling factor; a scalar is used for both axes.
        pad: ``(pad0, pad1)`` applied to both axes, or a 4-sequence
            ``(pad_x0, pad_x1, pad_y0, pad_y1)``.

    Returns:
        The resampled tensor computed by ``upfirdn2d_native``.

    Raises:
        NotImplementedError: if the module-level ``def_lib`` flag is true;
            only the native path is implemented in this function.
    """
    if not isinstance(up, abc.Iterable):
        up = (up, up)

    if not isinstance(down, abc.Iterable):
        down = (down, down)

    if len(pad) == 2:
        pad = (pad[0], pad[1], pad[0], pad[1])

    if def_lib:
        # Previously ``out`` was left unbound here (UnboundLocalError).
        raise NotImplementedError(
            'upfirdn2d: only the native implementation is available '
            '(def_lib must be False)')

    return upfirdn2d_native(input, kernel, *up, *down, *pad)


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                     pad_y0, pad_y1):
    """Pure-PyTorch upfirdn: zero-insert, pad/crop, filter, then decimate."""
    _, channel, in_h, in_w = input.shape
    # Fold batch and channel together; every plane is filtered independently.
    input = input.reshape(-1, in_h, in_w, 1)

    _, in_h, in_w, minor = input.shape
    kernel_h, kernel_w = kernel.shape

    # Upsample by inserting (up - 1) zeros after every sample on each axis.
    out = input.view(-1, in_h, 1, in_w, 1, minor)
    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    out = out.view(-1, in_h * up_y, in_w * up_x, minor)

    # Positive pads add zeros; negative pads crop (handled by the slice).
    out = F.pad(
        out,
        [0, 0,
         max(pad_x0, 0),
         max(pad_x1, 0),
         max(pad_y0, 0),
         max(pad_y1, 0)])
    out = out[:,
              max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0),
              max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0)]

    out = out.permute(0, 3, 1, 2)
    out = out.reshape(
        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    # F.conv2d computes cross-correlation, so flip the kernel to convolve.
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(
        -1,
        minor,
        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
    )
    out = out.permute(0, 2, 3, 1)
    # Downsample by keeping every down-th sample.
    out = out[:, ::down_y, ::down_x, :]

    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x

    return out.view(-1, channel, out_h, out_w)
def make_kernel(k):
    """Build a float32 filter from ``k`` whose entries sum to one.

    A 1-D ``k`` is expanded to 2-D by multiplying it with itself along both
    axes; anything else is used as given.  The result is divided by its sum.
    """
    kernel = torch.tensor(k, dtype=torch.float32)

    is_separable = kernel.ndim == 1
    if is_separable:
        # Row vector times column vector -> 2-D separable filter.
        kernel = kernel.unsqueeze(0) * kernel.unsqueeze(1)

    total = kernel.sum()
    kernel /= total

    return kernel
class EqualLinear(nn.Module):
    """Linear layer whose weight is rescaled at call time.

    The weight is drawn from ``randn`` and divided by ``lr_mul`` at
    construction; at call time it is multiplied by
    ``lr_mul / sqrt(in_dim)`` and the bias by ``lr_mul``.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        bias: create a learnable bias when true.
        bias_init: initial value of every bias entry.
        lr_mul: multiplier folded into the runtime weight and bias scale.
        activation: when truthy, the output goes through
            ``fused_leaky_relu`` (the bias is applied inside it).
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 bias=True,
                 bias_init=0,
                 lr_mul=1,
                 activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))

        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Return the (optionally activated) affine transform of ``input``."""
        weight = self.weight * self.scale
        # A layer built with bias=False has no bias to scale; multiplying
        # None by lr_mul used to raise TypeError.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if self.activation:
            out = F.linear(input, weight)
            return fused_leaky_relu(out, bias)

        return F.linear(input, weight, bias=bias)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
        )
self.upsample = upsample + self.downsample = downsample + + if upsample: + factor = 2 + p = (len(blur_kernel) - factor) - (kernel_size - 1) + pad0 = (p + 1) // 2 + factor - 1 + pad1 = p // 2 + 1 + + self.blur = Blur( + blur_kernel, pad=(pad0, pad1), upsample_factor=factor) + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + self.blur = Blur(blur_kernel, pad=(pad0, pad1)) + + fan_in = in_channel * kernel_size**2 + self.scale = 1 / math.sqrt(fan_in) + self.padding = kernel_size // 2 + + self.weight = nn.Parameter( + torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)) + + self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) + + self.demodulate = demodulate + self.fused = fused + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' + f'upsample={self.upsample}, downsample={self.downsample})') + + def forward(self, input, style): + batch, in_channel, height, width = input.shape + + if not self.fused: + weight = self.scale * self.weight.squeeze(0) + style = self.modulation(style) + + if self.demodulate: + w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, + 1) + dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() + + input = input * style.reshape(batch, in_channel, 1, 1) + + if self.upsample: + weight = weight.transpose(0, 1) + out = conv2d_gradfix.conv_transpose2d( + input, weight, padding=0, stride=2) + out = self.blur(out) + + elif self.downsample: + input = self.blur(input) + out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) + + else: + out = conv2d_gradfix.conv2d( + input, weight, padding=self.padding) + + if self.demodulate: + out = out * dcoefs.view(batch, -1, 1, 1) + + return out + + style = self.modulation(style).view(batch, 1, in_channel, 1, 1) + weight = self.scale * self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 
class NoiseInjection(nn.Module):
    """Add noise to an image, scaled by a single learnable weight."""

    def __init__(self):
        super().__init__()

        # One scalar weight, initialised to zero.
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, image, noise=None):
        """Return ``image + weight * noise``.

        When ``noise`` is omitted, standard-normal noise of shape
        ``(batch, 1, height, width)`` is sampled with the dtype and device
        of ``image``.
        """
        if noise is not None:
            return image + self.weight * noise

        n, _, h, w = image.shape
        sampled = image.new_empty(n, 1, h, w).normal_()

        return image + self.weight * sampled
class ToRGB(nn.Module):
    """Project features to a 3-channel image and merge it with a skip image.

    Args:
        in_channel: number of input feature channels.
        style_dim: size of the style vector given to the modulated conv.
        upsample: when true, create an ``Upsample`` that is applied to the
            skip image before it is added.
        blur_kernel: kernel forwarded to ``Upsample``.
    """

    def __init__(self,
                 in_channel,
                 style_dim,
                 upsample=True,
                 blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        if upsample:
            self.upsample = Upsample(blur_kernel)

        self.conv = ModulatedConv2d(
            in_channel, 3, 1, style_dim, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))

    def forward(self, input, style, skip=None):
        """Return the RGB projection of ``input``, plus ``skip`` if given."""
        out = self.conv(input, style)
        out = out + self.bias

        if skip is not None:
            # Built with upsample=False there is no ``self.upsample``; add
            # the skip image unchanged instead of raising AttributeError.
            upsample = getattr(self, 'upsample', None)
            if upsample is not None:
                skip = upsample(skip)

            out = out + skip

        return out
ToRGB(self.channels[4], style_dim, upsample=False) + + self.log_size = int(math.log(size, 2)) + self.num_layers = (self.log_size - 2) * 2 + 1 + + self.convs = nn.ModuleList() + self.upsamples = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + in_channel = self.channels[4] + + for layer_idx in range(self.num_layers): + res = (layer_idx + 5) // 2 + shape = [1, 1, 2**res, 2**res] + self.noises.register_buffer(f'noise_{layer_idx}', + torch.randn(*shape)) + + for i in range(3, self.log_size + 1): + out_channel = self.channels[2**i] + + self.convs.append( + StyledConv( + in_channel, + out_channel, + 3, + style_dim, + upsample=True, + blur_kernel=blur_kernel, + )) + + self.convs.append( + StyledConv( + out_channel, + out_channel, + 3, + style_dim, + blur_kernel=blur_kernel)) + + self.to_rgbs.append(ToRGB(out_channel, style_dim)) + + in_channel = out_channel + + self.n_latent = self.log_size * 2 - 2 + + def make_noise(self): + device = self.input.input.device + + noises = [torch.randn(1, 1, 2**2, 2**2, device=device)] + + for i in range(3, self.log_size + 1): + for _ in range(2): + noises.append(torch.randn(1, 1, 2**i, 2**i, device=device)) + + return noises + + def mean_latent(self, n_latent): + latent_in = torch.randn( + n_latent, self.style_dim, device=self.input.input.device) + latent = self.style(latent_in).mean(0, keepdim=True) + + return latent + + def get_latent(self, input): + return self.style(input) + + def forward( + self, + styles, + return_latents=False, + inject_index=None, + truncation=1, + truncation_latent=None, + input_is_latent=False, + noise=None, + randomize_noise=True, + ): + if not input_is_latent: + styles = [self.style(s) for s in styles] + + if noise is None: + if randomize_noise: + noise = [None] * self.num_layers + else: + noise = [ + getattr(self.noises, f'noise_{i}') + for i in range(self.num_layers) + ] + + if truncation < 1: + style_t = [] + + for style in styles: + style_t.append(truncation_latent + + 
class ConvLayer(nn.Sequential):
    """Sequence of optional blur, equalised convolution, optional activation.

    With ``downsample`` a ``Blur`` is prepended and the convolution uses
    stride 2 with no padding; otherwise stride 1 with ``kernel_size // 2``
    padding.  With ``activate`` the bias is carried by the trailing
    ``FusedLeakyReLU`` rather than by the convolution.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        stages = []

        if downsample:
            factor = 2
            total_pad = (len(blur_kernel) - factor) + (kernel_size - 1)
            blur_pad = ((total_pad + 1) // 2, total_pad // 2)
            stages.append(Blur(blur_kernel, pad=blur_pad))
            stride, padding = 2, 0
        else:
            stride, padding = 1, kernel_size // 2

        conv = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            padding=padding,
            stride=stride,
            bias=bias and not activate,
        )
        stages.append(conv)

        if activate:
            stages.append(FusedLeakyReLU(out_channel, bias=bias))

        super().__init__(*stages)

        self.padding = padding
class Discriminator(nn.Module):
    """Residual convolutional network mapping an image to one score per sample.

    A 1x1 ``ConvLayer`` is followed by one ``ResBlock`` per power of two from
    ``size`` down to 8, a minibatch standard-deviation channel, a final 3x3
    ``ConvLayer`` and two ``EqualLinear`` layers.
    """

    def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        # Channel count used at each spatial resolution.
        width_at = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        in_channel = width_at[size]
        stages = [ConvLayer(3, in_channel, 1)]

        top_level = int(math.log(size, 2))

        for level in range(top_level, 2, -1):
            out_channel = width_at[2**(level - 1)]
            stages.append(ResBlock(in_channel, out_channel, blur_kernel))
            in_channel = out_channel

        self.convs = nn.Sequential(*stages)

        self.stddev_group = 4
        self.stddev_feat = 1

        # "+ 1" accounts for the standard-deviation channel added in forward.
        self.final_conv = ConvLayer(in_channel + 1, width_at[4], 3)
        self.final_linear = nn.Sequential(
            EqualLinear(
                width_at[4] * 4 * 4, width_at[4], activation='fused_lrelu'),
            EqualLinear(width_at[4], 1),
        )

    def forward(self, input):
        """Return the output of the final linear layer for ``input``."""
        feat = self.convs(input)

        batch, channel, height, width = feat.shape
        group = min(batch, self.stddev_group)

        # Standard deviation across each group of samples, averaged into a
        # single extra feature map that is tiled back over the batch.
        grouped = feat.view(group, -1, self.stddev_feat,
                            channel // self.stddev_feat, height, width)
        stddev = torch.sqrt(grouped.var(0, unbiased=False) + 1e-8)
        stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
        stddev = stddev.repeat(group, 1, height, width)
        feat = torch.cat([feat, stddev], 1)

        feat = self.final_conv(feat)
        flat = feat.view(batch, -1)

        return self.final_linear(flat)
@PIPELINES.register_module(
    Tasks.image_generation, module_name=Pipelines.face_image_generation)
class FaceImageGenerationPipeline(Pipeline):
    """Generate a face image from an integer random seed with StyleGAN2."""

    def __init__(self, model: str):
        """
        use `model` to create a face image generation pipeline for prediction
        Args:
            model: model id on modelscope hub; the checkpoint is read from
                ``f'{model}/{ModelFile.TORCH_MODEL_FILE}'``.
        """
        super().__init__(model=model)
        self.size = 1024
        self.latent = 512
        self.n_mlp = 8
        self.channel_multiplier = 2
        self.truncation = 0.7
        self.truncation_mean = 4096
        self.generator = stylegan2.Generator(
            self.size,
            self.latent,
            self.n_mlp,
            channel_multiplier=self.channel_multiplier)

        self.model_file = f'{model}/{ModelFile.TORCH_MODEL_FILE}'

        # The generator is never moved off the CPU in this class, so load the
        # checkpoint onto the CPU; without map_location a checkpoint saved
        # from a GPU cannot be deserialised on a CPU-only host.
        checkpoint = torch.load(self.model_file, map_location='cpu')
        self.generator.load_state_dict(checkpoint['g_ema'])
        logger.info('load model done')

        # Mean latent used as the truncation anchor in ``forward``.
        self.mean_latent = None
        if self.truncation < 1:
            with torch.no_grad():
                self.mean_latent = self.generator.mean_latent(
                    self.truncation_mean)

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Pass the seed through unchanged."""
        return input

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Sample one image for the integer seed ``input``.

        Returns:
            A dict whose ``OutputKeys.OUTPUT_IMG`` entry is a uint8 array in
            height x width x channel layout with channels flipped to BGR.
        """
        assert isinstance(input, int), \
            f'seed must be an int, got {type(input).__name__}'
        torch.manual_seed(input)
        torch.cuda.manual_seed(input)
        torch.cuda.manual_seed_all(input)
        self.generator.eval()
        with torch.no_grad():
            sample_z = torch.randn(1, self.latent)

            sample, _ = self.generator([sample_z],
                                       truncation=self.truncation,
                                       truncation_latent=self.mean_latent)

            # Rescale by 0.5 and shift by 0.5, then clip to [0, 1] below.
            sample = sample * 0.5 + 0.5
            sample = sample.squeeze(0).permute(1, 2, 0).flip(2)  # RGB->BGR
            sample = np.clip(sample.float().cpu().numpy(), 0, 1) * 255

        return {OutputKeys.OUTPUT_IMG: sample.astype(np.uint8)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the forward result unchanged."""
        return inputs
class FaceGenerationTest(unittest.TestCase):
    """Smoke test for the face image generation pipeline."""

    def setUp(self) -> None:
        self.model_id = 'damo/cv_gan_face-image-generation'

    def pipeline_inference(self, pipeline: Pipeline, seed: int):
        """Run ``pipeline`` on ``seed`` and write any result to result.png."""
        result = pipeline(seed)
        if result is None:
            return

        image = result[OutputKeys.OUTPUT_IMG]
        cv2.imwrite('result.png', image)
        print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_modelhub(self):
        seed = 10
        face_generation = pipeline(Tasks.image_generation, model=self.model_id)
        self.pipeline_inference(face_generation, seed)