From fc89cf8f3ee171739e6d7628d7ba465ec9f7cfaa Mon Sep 17 00:00:00 2001 From: "tianxi.tl" Date: Tue, 2 Aug 2022 00:04:26 +0800 Subject: [PATCH] upload image_to_image_generation code Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9564875 --- data/test/images/img2img_style.jpg | 3 + modelscope/metainfo.py | 1 + modelscope/models/cv/__init__.py | 5 +- .../cv/image_to_image_generation/__init__.py | 2 + .../data/__init__.py | 24 + .../data/transforms.py | 121 ++++ .../cv/image_to_image_generation/model.py | 322 ++++++++++ .../models/__init__.py | 24 + .../models/autoencoder.py | 412 ++++++++++++ .../image_to_image_generation/models/clip.py | 418 ++++++++++++ .../image_to_image_generation/ops/__init__.py | 22 + .../ops/diffusion.py | 598 ++++++++++++++++++ .../image_to_image_generation/ops/losses.py | 35 + modelscope/pipelines/builder.py | 3 + modelscope/pipelines/cv/__init__.py | 3 + .../cv/image_to_image_generate_pipeline.py | 250 ++++++++ modelscope/utils/constant.py | 1 + .../pipelines/test_image2image_generation.py | 48 ++ 18 files changed, 2290 insertions(+), 2 deletions(-) create mode 100644 data/test/images/img2img_style.jpg create mode 100644 modelscope/models/cv/image_to_image_generation/__init__.py create mode 100644 modelscope/models/cv/image_to_image_generation/data/__init__.py create mode 100644 modelscope/models/cv/image_to_image_generation/data/transforms.py create mode 100644 modelscope/models/cv/image_to_image_generation/model.py create mode 100644 modelscope/models/cv/image_to_image_generation/models/__init__.py create mode 100644 modelscope/models/cv/image_to_image_generation/models/autoencoder.py create mode 100644 modelscope/models/cv/image_to_image_generation/models/clip.py create mode 100644 modelscope/models/cv/image_to_image_generation/ops/__init__.py create mode 100644 modelscope/models/cv/image_to_image_generation/ops/diffusion.py create mode 100644 modelscope/models/cv/image_to_image_generation/ops/losses.py create mode 100644 modelscope/pipelines/cv/image_to_image_generate_pipeline.py create mode 100644 tests/pipelines/test_image2image_generation.py diff --git a/data/test/images/img2img_style.jpg b/data/test/images/img2img_style.jpg new file mode 100644 index 00000000..1b361f11 --- /dev/null +++ b/data/test/images/img2img_style.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef06465535002fd565f3e50d16772bdcb8e47f474fb7d7c318510fff49ab1090 +size 212790 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index ec3ffc04..7713eb4f 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -91,6 +91,7 @@ class Pipelines(object): image2image_translation = 'image-to-image-translation' live_category = 'live-category' video_category = 'video-category' + image_to_image_generation = 'image-to-image-generation' # nlp tasks sentence_similarity = 'sentence-similarity' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index f5f12471..3a8a0e55 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -3,5 +3,6 @@ from . 
import (action_recognition, animal_recognition, cartoon,
                cmdssl_video_embedding, face_detection, face_generation,
                image_classification, image_color_enhance, image_colorization,
                image_denoise, image_instance_segmentation,
-               image_to_image_translation, object_detection,
-               product_retrieval_embedding, super_resolution, virual_tryon)
+               image_to_image_generation, image_to_image_translation,
+               object_detection, product_retrieval_embedding, super_resolution,
+               virual_tryon)
diff --git a/modelscope/models/cv/image_to_image_generation/__init__.py b/modelscope/models/cv/image_to_image_generation/__init__.py
new file mode 100644
index 00000000..fb408086
--- /dev/null
+++ b/modelscope/models/cv/image_to_image_generation/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from . import data, models, ops
diff --git a/modelscope/models/cv/image_to_image_generation/data/__init__.py b/modelscope/models/cv/image_to_image_generation/data/__init__.py
new file mode 100644
index 00000000..33c8cf44
--- /dev/null
+++ b/modelscope/models/cv/image_to_image_generation/data/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .transforms import PadToSquare
+
+else:
+    _import_structure = {
+        'transforms': ['PadToSquare'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
+
+# from .transforms import *  # noqa F403
diff --git a/modelscope/models/cv/image_to_image_generation/data/transforms.py b/modelscope/models/cv/image_to_image_generation/data/transforms.py
new file mode 100644
index 00000000..5376d813
--- /dev/null
+++ b/modelscope/models/cv/image_to_image_generation/data/transforms.py
@@ -0,0 +1,121 @@
+import math
+import random
+
+import torchvision.transforms.functional as TF
+from PIL import Image, ImageFilter
+
+__all__ = [
+    'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate',
+    'RandomGaussianBlur', 'RandomCrop'
+]
+
+
+class Identity(object):
+
+    def __call__(self, *args):
+        if len(args) == 0:
+            return None
+        elif len(args) == 1:
+            return args[0]
+        else:
+            return args
+
+
+class PadToSquare(object):
+
+    def __init__(self, fill=(255, 255, 255)):
+        self.fill = fill
+
+    def __call__(self, img):
+        w, h = img.size
+        if w != h:
+            if w > h:
+                t = (w - h) // 2
+                b = w - h - t
+                padding = (0, t, 0, b)
+            else:
+                left = (h - w) // 2
+                right = h - w - left
+                padding = (left, 0, right, 0)
+            img = TF.pad(img, padding, fill=self.fill)
+        return img
+
+
+class RandomScale(object):
+
+    def __init__(self,
+                 min_scale=0.5,
+                 max_scale=2.0,
+                 min_ratio=0.8,
+                 max_ratio=1.25):
+        self.min_scale = min_scale
+        self.max_scale = max_scale
+        self.min_ratio = min_ratio
+        self.max_ratio = max_ratio
+
+    def __call__(self, img):
+        w, h = img.size
+        scale = 2**random.uniform(
+            math.log2(self.min_scale), math.log2(self.max_scale))
+        ratio = 2**random.uniform(
+            math.log2(self.min_ratio), math.log2(self.max_ratio))
+        ow = int(w * scale * math.sqrt(ratio))
+        oh = int(h * scale / math.sqrt(ratio))
+        img = img.resize((ow, oh), Image.BILINEAR)
+        return img
+
+
+class RandomRotate(object):
+
+    def __init__(self,
+                 min_angle=-10.0,
+                 max_angle=10.0,
+                 padding=(255, 255, 255),
+                 p=0.5):
+        self.min_angle = min_angle
+        self.max_angle = max_angle
+        self.padding = padding
+        self.p = p
+
+    def __call__(self, img):
+        if random.random() < self.p:
+            angle = 
random.uniform(self.min_angle, self.max_angle) + img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding) + return img + + +class RandomGaussianBlur(object): + + def __init__(self, radius=5, p=0.5): + self.radius = radius + self.p = p + + def __call__(self, img): + if random.random() < self.p: + img = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) + return img + + +class RandomCrop(object): + + def __init__(self, size, padding=(255, 255, 255)): + self.size = size + self.padding = padding + + def __call__(self, img): + # pad + w, h = img.size + pad_w = max(0, self.size - w) + pad_h = max(0, self.size - h) + if pad_w > 0 or pad_h > 0: + half_w = pad_w // 2 + half_h = pad_h // 2 + pad = (half_w, half_h, pad_w - half_w, pad_h - half_h) + img = TF.pad(img, pad, fill=self.padding) + + # crop + w, h = img.size + x1 = random.randint(0, w - self.size) + y1 = random.randint(0, h - self.size) + img = img.crop((x1, y1, x1 + self.size, y1 + self.size)) + return img diff --git a/modelscope/models/cv/image_to_image_generation/model.py b/modelscope/models/cv/image_to_image_generation/model.py new file mode 100644 index 00000000..37479b43 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/model.py @@ -0,0 +1,322 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['UNet'] + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class Resample(nn.Module): + + def __init__(self, scale_factor=1.0): + assert scale_factor in [0.5, 1.0, 2.0] + super(Resample, self).__init__() + self.scale_factor = scale_factor + + def forward(self, x): + if self.scale_factor == 2.0: + x = F.interpolate(x, scale_factor=2, mode='nearest') + elif self.scale_factor == 0.5: + x = F.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.embedding = nn.Sequential(nn.SiLU(), + nn.Linear(embed_dim, out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, y): + identity = x + x = self.layer1(x) + x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1) + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class MultiHeadAttention(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0): + assert dim % num_heads == 0 + assert context_dim is None or context_dim % num_heads == 0 + context_dim = context_dim or dim + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = math.pow(self.head_dim, -0.25) + + # layers + self.q = nn.Linear(dim, 
dim, bias=False) + self.k = nn.Linear(context_dim, dim, bias=False) + self.v = nn.Linear(context_dim, dim, bias=False) + self.o = nn.Linear(dim, dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None): + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # compute attention + attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + attn = self.dropout(attn) + + # gather context + x = torch.einsum('bnij,bjnc->binc', attn, v) + x = x.reshape(b, -1, n * c) + + # output + x = self.o(x) + x = self.dropout(x) + return x + + +class GLU(nn.Module): + + def __init__(self, in_dim, out_dim): + super(GLU, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.proj = nn.Linear(in_dim, out_dim * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class TransformerBlock(nn.Module): + + def __init__(self, dim, context_dim, num_heads, dropout=0.0): + super(TransformerBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + # input + self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(dim, dim, 1) + + # self attention + self.norm2 = nn.LayerNorm(dim) + self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout) + + # cross attention + self.norm3 = nn.LayerNorm(dim) + self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads, + dropout) + + # ffn + self.norm4 = nn.LayerNorm(dim) + self.ffn = nn.Sequential( + GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim)) + + # output + self.conv2 = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.conv2.weight) + + def forward(self, x, context): + b, c, h, w = x.size() + identity = x + + # input + x = self.norm1(x) + x = self.conv1(x).view(b, c, -1).transpose(1, 2) + + # attention + x = x + self.self_attn(self.norm2(x)) + x = x + self.cross_attn(self.norm3(x), context) + x = x + self.ffn(self.norm4(x)) + + # output + x = x.transpose(1, 2).view(b, c, h, w) + x = self.conv2(x) + return x + identity + + +class UNet(nn.Module): + + def __init__(self, + resolution=64, + in_dim=3, + dim=192, + label_dim=512, + context_dim=512, + out_dim=3, + dim_mult=[1, 2, 3, 5], + num_heads=1, + head_dim=None, + num_res_blocks=2, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + dropout=0.0): + embed_dim = dim * 4 + super(UNet, self).__init__() + self.resolution = resolution + self.in_dim = in_dim + self.dim = dim + self.context_dim = context_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + self.clip_embedding = nn.Sequential( + nn.Linear(label_dim, context_dim), nn.SiLU(), + nn.Linear(context_dim, context_dim)) + + # encoder + self.encoder = nn.ModuleList( + [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in 
enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList( + [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads)) + in_dim = out_dim + self.encoder.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + self.encoder.append( + nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)) + shortcut_dims.append(out_dim) + scale /= 2.0 + + # middle + self.middle = nn.ModuleList([ + ResidualBlock(out_dim, embed_dim, out_dim, dropout), + TransformerBlock(out_dim, context_dim, num_heads), + ResidualBlock(out_dim, embed_dim, out_dim, dropout) + ]) + + # decoder + self.decoder = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, + out_dim, dropout) + ]) + if scale in attn_scales: + block.append( + TransformerBlock(out_dim, context_dim, num_heads, + dropout)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + block.append( + nn.Sequential( + Resample(scale_factor=2.0), + nn.Conv2d(out_dim, out_dim, 3, padding=1))) + scale *= 2.0 + self.decoder.append(block) + + # head + self.head = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.head[-1].weight) + + def forward(self, x, t, y): + # embeddings + t = self.time_embedding(sinusoidal_embedding(t, self.dim)) + y = self.clip_embedding(y) + + # encoder + xs = [] + for block in self.encoder: + x = self._forward_single(block, x, t, y) + xs.append(x) + + # middle + for block in self.middle: + x = self._forward_single(block, x, t, y) + + # decoder + for block in self.decoder: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single(block, x, t, y) + + # head + x = self.head(x) + return x + + def _forward_single(self, module, x, t, y): + if isinstance(module, ResidualBlock): + x = module(x, t) + elif isinstance(module, TransformerBlock): + x = module(x, y) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, t, y) + else: + x = module(x) + return x diff --git a/modelscope/models/cv/image_to_image_generation/models/__init__.py b/modelscope/models/cv/image_to_image_generation/models/__init__.py new file mode 100644 index 00000000..ec6a46fd --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/models/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .autoencoder import VQAutoencoder + from .clip import VisionTransformer + +else: + _import_structure = { + 'autoencoder': ['VQAutoencoder'], + 'clip': ['VisionTransformer'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_to_image_generation/models/autoencoder.py b/modelscope/models/cv/image_to_image_generation/models/autoencoder.py new file mode 100644 index 00000000..181472de --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/models/autoencoder.py @@ -0,0 +1,412 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator'] + + +def group_norm(dim): + return nn.GroupNorm(32, dim, eps=1e-6, affine=True) + + +class Resample(nn.Module): + + def __init__(self, dim, scale_factor): + super(Resample, self).__init__() + self.dim = dim + self.scale_factor = scale_factor + + # layers + if scale_factor == 2.0: + self.resample = nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode='nearest'), + nn.Conv2d(dim, dim, 3, padding=1)) + elif scale_factor == 0.5: + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=2, padding=0)) + else: + self.resample = nn.Identity() + + def forward(self, x): + return self.resample(x) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + group_norm(in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim), + nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Conv2d(in_dim, out_dim, + 1) if in_dim != out_dim else nn.Identity() + + # zero out the last layer params + nn.init.zeros_(self.residual[-1].weight) + + def forward(self, x): + return self.residual(x) + self.shortcut(x) + + +class AttentionBlock(nn.Module): + + def __init__(self, dim): + super(AttentionBlock, self).__init__() + self.dim = dim + self.scale = math.pow(dim, -0.25) + + # layers + self.norm = group_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, h, w = x.size() + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1) + + # compute attention + attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.einsum('bij,bcj->bci', attn, v) + x = x.reshape(b, c, h, w) + + # output + x = self.proj(x) + return x + identity + + +class Encoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Encoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in 
enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + downsamples.append(Resample(out_dim, scale_factor=0.5)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.downsamples(x) + x = self.middle(x) + x = self.head(x) + return x + + +class Decoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(Decoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # params + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + upsamples.append(Resample(out_dim, scale_factor=2.0)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + group_norm(out_dim), nn.SiLU(), + nn.Conv2d(out_dim, 3, 3, padding=1)) + + def forward(self, x): + x = self.conv1(x) + x = self.middle(x) + x = self.upsamples(x) + x = self.head(x) + return x + + +class VectorQuantizer(nn.Module): + + def __init__(self, codebook_size=8192, z_dim=3, beta=0.25): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size + self.z_dim = z_dim + self.beta = beta + + # init codebook + eps = math.sqrt(1.0 / codebook_size) + self.codebook = nn.Parameter( + torch.empty(codebook_size, z_dim).uniform_(-eps, eps)) + + def forward(self, z): + # preprocess + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + + # quantization + with torch.no_grad(): + tokens = torch.cdist(flatten, self.codebook).argmin(dim=1) + quantized = F.embedding(tokens, + self.codebook).view(b, h, w, + c).permute(0, 3, 1, 2) + + # compute loss + codebook_loss = F.mse_loss(quantized, z.detach()) + commitment_loss = F.mse_loss(quantized.detach(), z) + loss = codebook_loss + self.beta * commitment_loss + + # perplexity + counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) + # dist.all_reduce(counts) + p = counts / counts.sum() + perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) + + # postprocess + tokens = tokens.view(b, h, w) + quantized = z + (quantized - z).detach() + return quantized, tokens, loss, perplexity + + +class VQAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=3, + dim_mult=[1, 2, 
4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0, + codebook_size=8192, + beta=0.25): + super(VQAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.codebook_size = codebook_size + self.beta = beta + + # blocks + self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim, z_dim, 1) + self.quantizer = VectorQuantizer(codebook_size, z_dim, beta) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + z = self.encoder(x) + z = self.conv1(z) + z, tokens, loss, perplexity = self.quantizer(z) + z = self.conv2(z) + x = self.decoder(z) + return x, tokens, loss, perplexity + + def encode(self, imgs): + z = self.encoder(imgs) + z = self.conv1(z) + return z + + def decode(self, z): + r"""Absort the quantizer in the decoder. + """ + z = self.quantizer(z)[0] + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + @torch.no_grad() + def encode_to_tokens(self, imgs): + # preprocess + z = self.encoder(imgs) + z = self.conv1(z) + + # quantization + b, c, h, w = z.size() + flatten = z.permute(0, 2, 3, 1).reshape(-1, c) + tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1) + return tokens.view(b, -1) + + @torch.no_grad() + def decode_from_tokens(self, tokens): + # dequantization + z = F.embedding(tokens, self.quantizer.codebook) + + # postprocess + b, l, c = z.size() + h = w = int(math.sqrt(l)) + z = z.view(b, h, w, c).permute(0, 3, 1, 2) + z = self.conv2(z) + imgs = self.decoder(z) + return imgs + + +class KLAutoencoder(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + dropout=0.0): + super(KLAutoencoder, self).__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + + # blocks + self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, dropout) + self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1) + self.conv2 = nn.Conv2d(z_dim, z_dim, 1) + self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x = self.decode(z) + return x, mu, log_var + + def encode(self, x): + x = self.encoder(x) + mu, log_var = self.conv1(x).chunk(2, dim=1) + return mu, log_var + + def decode(self, z): + x = self.conv2(z) + x = self.decoder(x) + return x + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + +class PatchDiscriminator(nn.Module): + + def __init__(self, in_dim=3, dim=64, num_layers=3): + super(PatchDiscriminator, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.num_layers = num_layers + + # params + dims = [dim * min(8, 2**u) for u in range(num_layers + 1)] + + # layers + layers = [ + nn.Conv2d(in_dim, dim, 4, stride=2, padding=1), + nn.LeakyReLU(0.2) + ] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + stride = 1 if i == num_layers - 1 else 2 + layers += [ + nn.Conv2d( + in_dim, out_dim, 4, stride=stride, padding=1, bias=False), + nn.BatchNorm2d(out_dim), + nn.LeakyReLU(0.2) + ] + layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)] + self.layers = nn.Sequential(*layers) + + # initialize 
weights + self.apply(self.init_weights) + + def forward(self, x): + return self.layers(x) + + def init_weights(self, m): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0.0, 0.02) + elif isinstance(m, nn.BatchNorm2d): + nn.init.normal_(m.weight, 1.0, 0.02) + nn.init.zeros_(m.bias) diff --git a/modelscope/models/cv/image_to_image_generation/models/clip.py b/modelscope/models/cv/image_to_image_generation/models/clip.py new file mode 100644 index 00000000..35d9d882 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/models/clip.py @@ -0,0 +1,418 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather + +__all__ = [ + 'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', + 'clip_vit_l_14_336px', 'clip_vit_h_16' +] + + +def to_fp16(m): + if isinstance(m, (nn.Linear, nn.Conv2d)): + m.weight.data = m.weight.data.half() + if m.bias is not None: + m.bias.data = m.bias.data.half() + elif hasattr(m, 'head'): + p = getattr(m, 'head') + p.data = p.data.half() + + +class QuickGELU(nn.Module): + + def forward(self, x): + return x * torch.sigmoid(1.702 * x) + + +class LayerNorm(nn.LayerNorm): + r"""Subclass of nn.LayerNorm to handle fp16. + """ + + def forward(self, x): + return super(LayerNorm, self).forward(x.float()).type_as(x) + + +class SelfAttention(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + assert dim % num_heads == 0 + super(SelfAttention, self).__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = 1.0 / math.sqrt(self.head_dim) + + # layers + self.to_qkv = nn.Linear(dim, dim * 3) + self.attn_dropout = nn.Dropout(attn_dropout) + self.proj = nn.Linear(dim, dim) + self.proj_dropout = nn.Dropout(proj_dropout) + + def forward(self, x, mask=None): + r"""x: [B, L, C]. + mask: [*, L, L]. 
+ """ + b, l, _, n = *x.size(), self.num_heads + + # compute query, key, and value + q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) + q = q.reshape(l, b * n, -1).transpose(0, 1) + k = k.reshape(l, b * n, -1).transpose(0, 1) + v = v.reshape(l, b * n, -1).transpose(0, 1) + + # compute attention + attn = self.scale * torch.bmm(q, k.transpose(1, 2)) + if mask is not None: + attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + attn = self.attn_dropout(attn) + + # gather context + x = torch.bmm(attn, v) + x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) + + # output + x = self.proj(x) + x = self.proj_dropout(x) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): + super(AttentionBlock, self).__init__() + self.dim = dim + self.num_heads = num_heads + + # layers + self.norm1 = LayerNorm(dim) + self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) + self.norm2 = LayerNorm(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), + nn.Dropout(proj_dropout)) + + def forward(self, x, mask=None): + x = x + self.attn(self.norm1(x), mask) + x = x + self.mlp(self.norm2(x)) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size=224, + patch_size=16, + dim=768, + out_dim=512, + num_heads=12, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + assert image_size % patch_size == 0 + super(VisionTransformer, self).__init__() + self.image_size = image_size + self.patch_size = patch_size + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.num_patches = (image_size // patch_size)**2 + + # embeddings + gain = 1.0 / math.sqrt(dim) + self.patch_embedding = nn.Conv2d( + 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) + self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) + self.pos_embedding = nn.Parameter( + gain * torch.randn(1, self.num_patches + 1, dim)) + self.dropout = nn.Dropout(embedding_dropout) + + # transformer + self.pre_norm = LayerNorm(dim) + self.transformer = nn.Sequential(*[ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.post_norm = LayerNorm(dim) + + # head + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + def forward(self, x): + b, dtype = x.size(0), self.head.dtype + x = x.type(dtype) + + # patch-embedding + x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] + x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], + dim=1) + x = self.dropout(x + self.pos_embedding.type(dtype)) + x = self.pre_norm(x) + + # transformer + x = self.transformer(x) + + # head + x = self.post_norm(x) + x = torch.mm(x[:, 0, :], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class TextTransformer(nn.Module): + + def __init__(self, + vocab_size, + text_len, + dim=512, + out_dim=512, + num_heads=8, + num_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(TextTransformer, self).__init__() + self.vocab_size = vocab_size + self.text_len = text_len + self.dim = dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + + # embeddings + self.token_embedding = nn.Embedding(vocab_size, dim) + self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) + self.dropout = 
nn.Dropout(embedding_dropout) + + # transformer + self.transformer = nn.ModuleList([ + AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) + for _ in range(num_layers) + ]) + self.norm = LayerNorm(dim) + + # head + gain = 1.0 / math.sqrt(dim) + self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) + + # causal attention mask + self.register_buffer('attn_mask', + torch.tril(torch.ones(1, text_len, text_len))) + + def forward(self, x): + eot, dtype = x.argmax(dim=-1), self.head.dtype + + # embeddings + x = self.dropout( + self.token_embedding(x).type(dtype) + + self.pos_embedding.type(dtype)) + + # transformer + for block in self.transformer: + x = block(x, self.attn_mask) + + # head + x = self.norm(x) + x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) + return x + + def fp16(self): + return self.apply(to_fp16) + + +class CLIP(nn.Module): + + def __init__(self, + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + vocab_size=49408, + text_len=77, + text_dim=512, + text_heads=8, + text_layers=12, + attn_dropout=0.0, + proj_dropout=0.0, + embedding_dropout=0.0): + super(CLIP, self).__init__() + self.embed_dim = embed_dim + self.image_size = image_size + self.patch_size = patch_size + self.vision_dim = vision_dim + self.vision_heads = vision_heads + self.vision_layers = vision_layers + self.vocab_size = vocab_size + self.text_len = text_len + self.text_dim = text_dim + self.text_heads = text_heads + self.text_layers = text_layers + + # models + self.visual = VisionTransformer( + image_size=image_size, + patch_size=patch_size, + dim=vision_dim, + out_dim=embed_dim, + num_heads=vision_heads, + num_layers=vision_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.textual = TextTransformer( + vocab_size=vocab_size, + text_len=text_len, + dim=text_dim, + out_dim=embed_dim, + num_heads=text_heads, + num_layers=text_layers, + attn_dropout=attn_dropout, + proj_dropout=proj_dropout, + embedding_dropout=embedding_dropout) + self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) + + def forward(self, imgs, txt_tokens): + r"""imgs: [B, C, H, W] of torch.float32. + txt_tokens: [B, T] of torch.long. 
+ """ + xi = self.visual(imgs) + xt = self.textual(txt_tokens) + + # normalize features + xi = F.normalize(xi, p=2, dim=1) + xt = F.normalize(xt, p=2, dim=1) + + # gather features from all ranks + full_xi = ops.diff_all_gather(xi) + full_xt = ops.diff_all_gather(xt) + + # logits + scale = self.log_scale.exp() + logits_i2t = scale * torch.mm(xi, full_xt.t()) + logits_t2i = scale * torch.mm(xt, full_xi.t()) + + # labels + labels = torch.arange( + len(xi) * ops.get_rank(), + len(xi) * (ops.get_rank() + 1), + dtype=torch.long, + device=xi.device) + return logits_i2t, logits_t2i, labels + + def init_weights(self): + # embeddings + nn.init.normal_(self.textual.token_embedding.weight, std=0.02) + nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) + + # attentions + for modality in ['visual', 'textual']: + dim = self.vision_dim if modality == 'visual' else 'textual' + transformer = getattr(self, modality).transformer + proj_gain = (1.0 / math.sqrt(dim)) * ( + 1.0 / math.sqrt(2 * transformer.num_layers)) + attn_gain = 1.0 / math.sqrt(dim) + mlp_gain = 1.0 / math.sqrt(2.0 * dim) + for block in transformer.layers: + nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) + nn.init.normal_(block.attn.proj.weight, std=proj_gain) + nn.init.normal_(block.mlp[0].weight, std=mlp_gain) + nn.init.normal_(block.mlp[2].weight, std=proj_gain) + + def param_groups(self): + groups = [{ + 'params': [ + p for n, p in self.named_parameters() + if 'norm' in n or n.endswith('bias') + ], + 'weight_decay': + 0.0 + }, { + 'params': [ + p for n, p in self.named_parameters() + if not ('norm' in n or n.endswith('bias')) + ] + }] + return groups + + def fp16(self): + return self.apply(to_fp16) + + +def clip_vit_b_32(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=32, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_b_16(**kwargs): + return CLIP( + embed_dim=512, + image_size=224, + patch_size=16, + vision_dim=768, + vision_heads=12, + vision_layers=12, + text_dim=512, + text_heads=8, + text_layers=12, + **kwargs) + + +def clip_vit_l_14(**kwargs): + return CLIP( + embed_dim=768, + image_size=224, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_l_14_336px(**kwargs): + return CLIP( + embed_dim=768, + image_size=336, + patch_size=14, + vision_dim=1024, + vision_heads=16, + vision_layers=24, + text_dim=768, + text_heads=12, + text_layers=12, + **kwargs) + + +def clip_vit_h_16(**kwargs): + return CLIP( + embed_dim=1024, + image_size=256, + patch_size=16, + vision_dim=1280, + vision_heads=16, + vision_layers=32, + text_dim=1024, + text_heads=16, + text_layers=24, + **kwargs) diff --git a/modelscope/models/cv/image_to_image_generation/ops/__init__.py b/modelscope/models/cv/image_to_image_generation/ops/__init__.py new file mode 100644 index 00000000..49674b49 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/ops/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .diffusion import GaussianDiffusion, beta_schedule + +else: + _import_structure = { + 'diffusion': ['GaussianDiffusion', 'beta_schedule'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_to_image_generation/ops/diffusion.py b/modelscope/models/cv/image_to_image_generation/ops/diffusion.py new file mode 100644 index 00000000..bcbb6402 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/ops/diffusion.py @@ -0,0 +1,598 @@ +import math + +import torch + +from .losses import discretized_gaussian_log_likelihood, kl_divergence + +__all__ = ['GaussianDiffusion', 'beta_schedule'] + + +def _i(tensor, t, x): + r"""Index tensor using t and format the output according to x. + """ + shape = (x.size(0), ) + (1, ) * (x.ndim - 1) + return tensor[t].view(shape).to(x) + + +def beta_schedule(schedule, + num_timesteps=1000, + init_beta=None, + last_beta=None): + if schedule == 'linear': + scale = 1000.0 / num_timesteps + init_beta = init_beta or scale * 0.0001 + last_beta = last_beta or scale * 0.02 + return torch.linspace( + init_beta, last_beta, num_timesteps, dtype=torch.float64) + elif schedule == 'quadratic': + init_beta = init_beta or 0.0015 + last_beta = last_beta or 0.0195 + return torch.linspace( + init_beta**0.5, last_beta**0.5, num_timesteps, + dtype=torch.float64)**2 + elif schedule == 'cosine': + betas = [] + for step in range(num_timesteps): + t1 = step / num_timesteps + t2 = (step + 1) / num_timesteps + + # fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + def fn(u): + return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 + + betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) + return torch.tensor(betas, dtype=torch.float64) + else: + raise ValueError(f'Unsupported schedule: {schedule}') + + +class GaussianDiffusion(object): + + def __init__(self, + betas, + mean_type='eps', + var_type='learned_range', + loss_type='mse', + rescale_timesteps=False): + # check input + if not isinstance(betas, torch.DoubleTensor): + betas = torch.tensor(betas, dtype=torch.float64) + assert min(betas) > 0 and max(betas) <= 1 + assert mean_type in ['x0', 'x_{t-1}', 'eps'] + assert var_type in [ + 'learned', 'learned_range', 'fixed_large', 'fixed_small' + ] + assert loss_type in [ + 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' + ] + self.betas = betas + self.num_timesteps = len(betas) + self.mean_type = mean_type + self.var_type = var_type + self.loss_type = loss_type + self.rescale_timesteps = rescale_timesteps + + # alphas + alphas = 1 - self.betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + self.alphas_cumprod_prev = torch.cat( + [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) + self.alphas_cumprod_next = torch.cat( + [self.alphas_cumprod[1:], + alphas.new_zeros([1])]) + + # q(x_t | x_{t-1}) + self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 + - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = torch.log(1.0 + - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod + - 1) + + # q(x_{t-1} | x_t, x_0) + self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( + 1.0 - 
self.alphas_cumprod) + self.posterior_log_variance_clipped = torch.log( + self.posterior_variance.clamp(1e-20)) + self.posterior_mean_coef1 = betas * torch.sqrt( + self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + self.posterior_mean_coef2 = ( + 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( + 1.0 - self.alphas_cumprod) + + def q_sample(self, x0, t, noise=None): + r"""Sample from q(x_t | x_0). + """ + noise = torch.randn_like(x0) if noise is None else noise + return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( + self.sqrt_one_minus_alphas_cumprod, t, x0) * noise + + def q_mean_variance(self, x0, t): + r"""Distribution of q(x_t | x_0). + """ + mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 + var = _i(1.0 - self.alphas_cumprod, t, x0) + log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) + return mu, var, log_var + + def q_posterior_mean_variance(self, x0, xt, t): + r"""Distribution of q(x_{t-1} | x_t, x_0). + """ + mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( + self.posterior_mean_coef2, t, xt) * xt + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + return mu, var, log_var + + @torch.no_grad() + def p_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t). + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). + """ + # predict distribution of p(x_{t-1} | x_t) + mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, + guide_scale) + + # random sample (with optional conditional function) + noise = torch.randn_like(xt) + # no noise when t == 0 + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + if condition_fn is not None: + grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) + mu = mu.float() + var * grad.float() + xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise + return xt_1, x0 + + @torch.no_grad() + def p_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None): + r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). + """ + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + for step in torch.arange(self.num_timesteps).flip(0): + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale) + return xt + + def p_mean_variance(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None): + r"""Distribution of p(x_{t-1} | x_t). 
+ """ + # predict distribution + if guide_scale is None: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + else: + # classifier-free guidance + # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) + assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 + assert self.mean_type == 'eps' + y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) + u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) + out = torch.cat( + [ + u_out[:, :3] + guide_scale * # noqa W504 + (y_out[:, :3] - u_out[:, :3]), + y_out[:, 3:] + ], + dim=1) + + # compute variance + if self.var_type == 'learned': + out, log_var = out.chunk(2, dim=1) + var = torch.exp(log_var) + elif self.var_type == 'learned_range': + out, fraction = out.chunk(2, dim=1) + min_log_var = _i(self.posterior_log_variance_clipped, t, xt) + max_log_var = _i(torch.log(self.betas), t, xt) + fraction = (fraction + 1) / 2.0 + log_var = fraction * max_log_var + (1 - fraction) * min_log_var + var = torch.exp(log_var) + elif self.var_type == 'fixed_large': + var = _i( + torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, + xt) + log_var = torch.log(var) + elif self.var_type == 'fixed_small': + var = _i(self.posterior_variance, t, xt) + log_var = _i(self.posterior_log_variance_clipped, t, xt) + + # compute mean and x0 + if self.mean_type == 'x_{t-1}': + mu = out # x_{t-1} + x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, + xt) * xt + elif self.mean_type == 'x0': + x0 = out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + elif self.mean_type == 'eps': + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * out + mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) + + # restrict the range of x0 + if percentile is not None: + assert percentile > 0 and percentile <= 1 # e.g., 0.995 + s = torch.quantile( + x0.flatten(1).abs(), percentile, + dim=1).clamp_(1.0).view(-1, 1, 1, 1) + x0 = torch.min(s, torch.max(-s, x0)) / s + elif clamp is not None: + x0 = x0.clamp(-clamp, clamp) + return mu, var, log_var, x0 + + @torch.no_grad() + def ddim_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas = _i(self.alphas_cumprod, t, xt) + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504 + (1 - alphas) * # noqa W504 + (1 - alphas / alphas_prev)) + + # random sample + noise = torch.randn_like(xt) + direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps + mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + 
direction + mask * sigmas * noise + return xt_1, x0 + + @torch.no_grad() + def ddim_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + ddim_timesteps=20, + eta=0.0): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, guide_scale, + ddim_timesteps, eta) + return xt + + @torch.no_grad() + def ddim_reverse_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). + """ + stride = self.num_timesteps // ddim_timesteps + + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, + percentile, guide_scale) + + # derive variables + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + alphas_next = _i( + torch.cat( + [self.alphas_cumprod, + self.alphas_cumprod.new_zeros([1])]), + (t + stride).clamp(0, self.num_timesteps), xt) + + # reverse sample + mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps + return mu, x0 + + @torch.no_grad() + def ddim_reverse_sample_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None, + guide_scale=None, + ddim_timesteps=20): + # prepare input + b, c, h, w = x0.size() + xt = x0 + + # reconstruction steps + steps = torch.arange(0, self.num_timesteps, + self.num_timesteps // ddim_timesteps) + for step in steps: + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, + percentile, guide_scale, + ddim_timesteps) + return xt + + @torch.no_grad() + def plms_sample(self, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + r"""Sample from p(x_{t-1} | x_t) using PLMS. + - condition_fn: for classifier-based guidance (guided-diffusion). + - guide_scale: for classifier-free guidance (glide/dalle-2). 
+ """ + stride = self.num_timesteps // plms_timesteps + + # function for compute eps + def compute_eps(xt, t): + # predict distribution of p(x_{t-1} | x_t) + _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile, guide_scale) + + # condition + if condition_fn is not None: + # x0 -> eps + alpha = _i(self.alphas_cumprod, t, xt) + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt + - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) + eps = eps - (1 - alpha).sqrt() * condition_fn( + xt, self._scale_timesteps(t), **model_kwargs) + + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # derive eps + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + return eps + + # function for compute x_0 and x_{t-1} + def compute_x0(eps, t): + # eps -> x0 + x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) * eps + + # deterministic sample + alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) + direction = torch.sqrt(1 - alphas_prev) * eps + # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) + xt_1 = torch.sqrt(alphas_prev) * x0 + direction + return xt_1, x0 + + # PLMS sample + eps = compute_eps(xt, t) + if len(eps_cache) == 0: + # 2nd order pseudo improved Euler + xt_1, x0 = compute_x0(eps, t) + eps_next = compute_eps(xt_1, (t - stride).clamp(0)) + eps_prime = (eps + eps_next) / 2.0 + elif len(eps_cache) == 1: + # 2nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (3 * eps - eps_cache[-1]) / 2.0 + elif len(eps_cache) == 2: + # 3nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (23 * eps - 16 * eps_cache[-1] + + 5 * eps_cache[-2]) / 12.0 + elif len(eps_cache) >= 3: + # 4nd order pseudo linear multistep (Adams-Bashforth) + eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] + - 9 * eps_cache[-3]) / 24.0 + xt_1, x0 = compute_x0(eps_prime, t) + return xt_1, x0, eps + + @torch.no_grad() + def plms_sample_loop(self, + noise, + model, + model_kwargs={}, + clamp=None, + percentile=None, + condition_fn=None, + guide_scale=None, + plms_timesteps=20): + # prepare input + b, c, h, w = noise.size() + xt = noise + + # diffusion process + steps = (1 + torch.arange(0, self.num_timesteps, + self.num_timesteps // plms_timesteps)).clamp( + 0, self.num_timesteps - 1).flip(0) + eps_cache = [] + for step in steps: + # PLMS sampling step + t = torch.full((b, ), step, dtype=torch.long, device=xt.device) + xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, + percentile, condition_fn, + guide_scale, plms_timesteps, + eps_cache) + + # update eps cache + eps_cache.append(eps) + if len(eps_cache) >= 4: + eps_cache.pop(0) + return xt + + def loss(self, x0, t, model, model_kwargs={}, noise=None): + noise = torch.randn_like(x0) if noise is None else noise + xt = self.q_sample(x0, t, noise=noise) + + # compute loss + if self.loss_type in ['kl', 'rescaled_kl']: + loss, _ = self.variational_lower_bound(x0, xt, t, model, + model_kwargs) + if self.loss_type == 'rescaled_kl': + loss = loss * self.num_timesteps + elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: + out = model(xt, self._scale_timesteps(t), **model_kwargs) + + # VLB for variation + loss_vlb = 0.0 + if self.var_type in ['learned', 'learned_range']: + out, var = out.chunk(2, dim=1) + frozen = torch.cat([ + out.detach(), var + ], dim=1) # learn var without affecting the 
prediction of mean + loss_vlb, _ = self.variational_lower_bound( + x0, xt, t, model=lambda *args, **kwargs: frozen) + if self.loss_type.startswith('rescaled_'): + loss_vlb = loss_vlb * self.num_timesteps / 1000.0 + + # MSE/L1 for x0/eps + target = { + 'eps': noise, + 'x0': x0, + 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] + }[self.mean_type] + loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 + ).abs().flatten(1).mean(dim=1) + + # total loss + loss = loss + loss_vlb + return loss + + def variational_lower_bound(self, + x0, + xt, + t, + model, + model_kwargs={}, + clamp=None, + percentile=None): + # compute groundtruth and predicted distributions + mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) + mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, + clamp, percentile) + + # compute KL loss + kl = kl_divergence(mu1, log_var1, mu2, log_var2) + kl = kl.flatten(1).mean(dim=1) / math.log(2.0) + + # compute discretized NLL loss (for p(x0 | x1) only) + nll = -discretized_gaussian_log_likelihood( + x0, mean=mu2, log_scale=0.5 * log_var2) + nll = nll.flatten(1).mean(dim=1) / math.log(2.0) + + # NLL for p(x0 | x1) and KL otherwise + vlb = torch.where(t == 0, nll, kl) + return vlb, x0 + + @torch.no_grad() + def variational_lower_bound_loop(self, + x0, + model, + model_kwargs={}, + clamp=None, + percentile=None): + r"""Compute the entire variational lower bound, measured in bits-per-dim. + """ + # prepare input and output + b, c, h, w = x0.size() + metrics = {'vlb': [], 'mse': [], 'x0_mse': []} + + # loop + for step in torch.arange(self.num_timesteps).flip(0): + # compute VLB + t = torch.full((b, ), step, dtype=torch.long, device=x0.device) + noise = torch.randn_like(x0) + xt = self.q_sample(x0, t, noise) + vlb, pred_x0 = self.variational_lower_bound( + x0, xt, t, model, model_kwargs, clamp, percentile) + + # predict eps from x0 + eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( + self.sqrt_recipm1_alphas_cumprod, t, xt) + + # collect metrics + metrics['vlb'].append(vlb) + metrics['x0_mse'].append( + (pred_x0 - x0).square().flatten(1).mean(dim=1)) + metrics['mse'].append( + (eps - noise).square().flatten(1).mean(dim=1)) + metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} + + # compute the prior KL term for VLB, measured in bits-per-dim + mu, _, log_var = self.q_mean_variance(x0, t) + kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), + torch.zeros_like(log_var)) + kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) + + # update metrics + metrics['prior_bits_per_dim'] = kl_prior + metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior + return metrics + + def _scale_timesteps(self, t): + if self.rescale_timesteps: + return t.float() * 1000.0 / self.num_timesteps + return t diff --git a/modelscope/models/cv/image_to_image_generation/ops/losses.py b/modelscope/models/cv/image_to_image_generation/ops/losses.py new file mode 100644 index 00000000..23e8d246 --- /dev/null +++ b/modelscope/models/cv/image_to_image_generation/ops/losses.py @@ -0,0 +1,35 @@ +import math + +import torch + +__all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] + + +def kl_divergence(mu1, logvar1, mu2, logvar2): + return 0.5 * ( + -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504 + ((mu1 - mu2)**2) * torch.exp(-logvar2)) + + +def standard_normal_cdf(x): + r"""A fast approximation of the cumulative distribution function of the standard normal. 
+    """
+    return 0.5 * (1.0 + torch.tanh(
+        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x0, mean, log_scale):
+    assert x0.shape == mean.shape == log_scale.shape
+    cx = x0 - mean
+    inv_stdv = torch.exp(-log_scale)
+    cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0))
+    cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0))
+    log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+    log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+    cdf_delta = cdf_plus - cdf_min
+    log_probs = torch.where(
+        x0 < -0.999, log_cdf_plus,
+        torch.where(x0 > 0.999, log_one_minus_cdf_min,
+                    torch.log(cdf_delta.clamp(min=1e-12))))
+    assert log_probs.shape == x0.shape
+    return log_probs
diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py
index 50652ac1..204a7208 100644
--- a/modelscope/pipelines/builder.py
+++ b/modelscope/pipelines/builder.py
@@ -118,6 +118,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.image_classification_dailylife:
     (Pipelines.daily_image_classification,
      'damo/cv_vit-base_image-classification_Dailylife-labels'),
+    Tasks.image_to_image_generation:
+    (Pipelines.image_to_image_generation,
+     'damo/cv_latent_diffusion_image2image_generate'),
 }
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index e66176e4..c2fb9eaa 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -20,6 +20,7 @@ if TYPE_CHECKING:
     from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
     from .image_matting_pipeline import ImageMattingPipeline
     from .image_super_resolution_pipeline import ImageSuperResolutionPipeline
+    from .image_to_image_generate_pipeline import Image2ImageGenerationPipeline
     from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline
     from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline
     from .style_transfer_pipeline import StyleTransferPipeline
@@ -51,6 +52,8 @@ else:
         'product_retrieval_embedding_pipeline':
         ['ProductRetrievalEmbeddingPipeline'],
         'live_category_pipeline': ['LiveCategoryPipeline'],
+        'image_to_image_generate_pipeline':
+        ['Image2ImageGenerationPipeline'],
         'ocr_detection_pipeline': ['OCRDetectionPipeline'],
         'style_transfer_pipeline': ['StyleTransferPipeline'],
         'video_category_pipeline': ['VideoCategoryPipeline'],
diff --git a/modelscope/pipelines/cv/image_to_image_generate_pipeline.py b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py
new file mode 100644
index 00000000..6533a14c
--- /dev/null
+++ b/modelscope/pipelines/cv/image_to_image_generate_pipeline.py
@@ -0,0 +1,250 @@
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import PIL
+import torch
+import torch.nn.functional as F
+import torchvision.transforms as T
+import torchvision.transforms.functional as TF
+from PIL import Image
+from torchvision.utils import save_image
+
+import modelscope.models.cv.image_to_image_generation.data as data
+import modelscope.models.cv.image_to_image_generation.models as models
+import modelscope.models.cv.image_to_image_generation.ops as ops
+from modelscope.metainfo import Pipelines
+from modelscope.models.cv.image_to_image_generation.model import UNet
+from modelscope.models.cv.image_to_image_generation.models.clip import \
+    VisionTransformer
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_to_image_generation,
+    module_name=Pipelines.image_to_image_generation)
+class Image2ImageGenerationPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create an image-to-image generation pipeline
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model)
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading config from {config_path}')
+        self.cfg = Config.from_file(config_path)
+        if torch.cuda.is_available():
+            self._device = torch.device('cuda')
+        else:
+            self._device = torch.device('cpu')
+        self.repetition = 4
+        # load vit model
+        vit_model_path = osp.join(self.model,
+                                  self.cfg.ModelPath.vit_model_path)
+        logger.info(f'loading vit model from {vit_model_path}')
+        self.vit = VisionTransformer(
+            image_size=self.cfg.Params.vit.vit_image_size,
+            patch_size=self.cfg.Params.vit.vit_patch_size,
+            dim=self.cfg.Params.vit.vit_dim,
+            out_dim=self.cfg.Params.vit.vit_out_dim,
+            num_heads=self.cfg.Params.vit.vit_num_heads,
+            num_layers=self.cfg.Params.vit.vit_num_layers).eval(
+            ).requires_grad_(False).to(self._device)  # noqa E123
+        state = torch.load(vit_model_path, map_location=self._device)
+        state = {
+            k[len('visual.'):]: v
+            for k, v in state.items() if k.startswith('visual.')
+        }
+        self.vit.load_state_dict(state)
+        logger.info('load vit model done')
+
+        # load autoencoder model
+        ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path)
+        logger.info(f'loading autoencoder model from {ae_model_path}')
+        self.autoencoder = models.VQAutoencoder(
+            dim=self.cfg.Params.ae.ae_dim,
+            z_dim=self.cfg.Params.ae.ae_z_dim,
+            dim_mult=self.cfg.Params.ae.ae_dim_mult,
+            attn_scales=self.cfg.Params.ae.ae_attn_scales,
+            codebook_size=self.cfg.Params.ae.ae_codebook_size).eval(
+            ).requires_grad_(False).to(self._device)  # noqa E123
+        self.autoencoder.load_state_dict(
+            torch.load(ae_model_path, map_location=self._device))
+        logger.info('load autoencoder model done')
+
+        # load decoder model
+        decoder_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        logger.info(f'loading decoder model from {decoder_model_path}')
+        self.decoder = UNet(
+            resolution=self.cfg.Params.unet.unet_resolution,
+            in_dim=self.cfg.Params.unet.unet_in_dim,
+            dim=self.cfg.Params.unet.unet_dim,
+            label_dim=self.cfg.Params.vit.vit_out_dim,
+            context_dim=self.cfg.Params.unet.unet_context_dim,
+            out_dim=self.cfg.Params.unet.unet_out_dim,
+            dim_mult=self.cfg.Params.unet.unet_dim_mult,
+            num_heads=self.cfg.Params.unet.unet_num_heads,
+            head_dim=None,
+            num_res_blocks=self.cfg.Params.unet.unet_res_blocks,
+            attn_scales=self.cfg.Params.unet.unet_attn_scales,
+            dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_(
+                False).to(self._device)
+        self.decoder.load_state_dict(
+            torch.load(decoder_model_path, map_location=self._device))
+        logger.info('load decoder model done')
+
+        # diffusion
+        logger.info('Initializing diffusion ...')
+        betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule,
+                                  self.cfg.Params.diffusion.num_timesteps)
+        self.diffusion = ops.GaussianDiffusion(
+            betas=betas,
+            mean_type=self.cfg.Params.diffusion.mean_type,
+            var_type=self.cfg.Params.diffusion.var_type,
+            loss_type=self.cfg.Params.diffusion.loss_type,
+            rescale_timesteps=False)
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        input_img_list = []
+        if isinstance(input, str):
+            input_img_list = [input]
+            input_type = 0
+        elif isinstance(input, tuple) and len(input) == 2:
+            input_img_list = list(input)
+            input_type = 1
+        else:
+            raise TypeError(
+                f'modelscope error: only "str" or "tuple (img1, img2)" inputs '
+                f'are supported, but got {type(input)}')
+
+        if input_type == 0:
+            logger.info('Processing Similar Image Generation mode')
+        elif input_type == 1:
+            logger.info('Processing Interpolation mode')
+
+        img_list = []
+        for i, input_img in enumerate(input_img_list):
+            img = LoadImage.convert_to_img(input_img)
+            logger.info(f'Load {i}-th image done')
+            img_list.append(img)
+
+        transforms = T.Compose([
+            data.PadToSquare(),
+            T.Resize(
+                self.cfg.DATA.scale_size,
+                interpolation=T.InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std)
+        ])
+
+        y_list = []
+        for img in img_list:
+            img = transforms(img)
+            imgs = torch.unsqueeze(img, 0)
+            imgs = imgs.to(self._device)
+            imgs_x0 = self.autoencoder.encode(imgs)
+            b, c, h, w = imgs_x0.shape
+            aug_imgs = TF.normalize(
+                F.interpolate(
+                    imgs.add(1).div(2), (self.cfg.Params.vit.vit_image_size,
+                                         self.cfg.Params.vit.vit_image_size),
+                    mode='bilinear',
+                    align_corners=True), self.cfg.Params.vit.vit_mean,
+                self.cfg.Params.vit.vit_std)
+            uy = self.vit(aug_imgs)
+            y = F.normalize(uy, p=2, dim=1)
+            y_list.append(y)
+
+        if input_type == 0:
+            result = {
+                'image_data': y_list[0],
+                'c': c,
+                'h': h,
+                'w': w,
+                'type': input_type
+            }
+        elif input_type == 1:
+            result = {
+                'image_data': y_list[0],
+                'image_data_s': y_list[1],
+                'c': c,
+                'h': h,
+                'w': w,
+                'type': input_type
+            }
+        return result
+
+    @torch.no_grad()
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        type_ = input['type']
+        if type_ == 0:
+            # Similar Image Generation #
+            y = input['image_data']
+
+            # fix seed
+            torch.manual_seed(1 * 8888)
+            torch.cuda.manual_seed(1 * 8888)
+            i_y = y.repeat(self.repetition, 1)
+
+            # sample images
+            x0 = self.diffusion.ddim_sample_loop(
+                noise=torch.randn(self.repetition, input['c'], input['h'],
+                                  input['w']).to(self._device),
+                model=self.decoder,
+                model_kwargs=[{
+                    'y': i_y
+                }, {
+                    'y': torch.zeros_like(i_y)
+                }],
+                guide_scale=1.0,
+                clamp=None,
+                ddim_timesteps=50,
+                eta=1.0)
+            i_gen_imgs = self.autoencoder.decode(x0)
+            return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
+        else:
+            # Interpolation #
+            # get content-style pairs
+            y = input['image_data']
+            y_s = input['image_data_s']
+
+            # fix seed
+            torch.manual_seed(1 * 8888)
+            torch.cuda.manual_seed(1 * 8888)
+            noise = torch.randn(self.repetition, input['c'], input['h'],
+                                input['w']).to(self._device)
+
+            # interpolation between y_cid and y_sid
+            factors = torch.linspace(0, 1, self.repetition).unsqueeze(1).to(
+                self._device)
+            i_y = (1 - factors) * y + factors * y_s
+
+            # sample images
+            x0 = self.diffusion.ddim_sample_loop(
+                noise=noise,
+                model=self.decoder,
+                model_kwargs=[{
+                    'y': i_y
+                }, {
+                    'y': torch.zeros_like(i_y)
+                }],
+                guide_scale=3.0,
+                clamp=None,
+                ddim_timesteps=50,
+                eta=0.0)
+            i_gen_imgs = self.autoencoder.decode(x0)
+            return {OutputKeys.OUTPUT_IMG: i_gen_imgs}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return inputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 87a282ca..eb4bd5b6 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -42,6 +42,7 @@ class CVTasks(object):
     video_category = 'video-category'
     image_classification_imagenet = 'image-classification-imagenet'
     image_classification_dailylife = 'image-classification-dailylife'
+    image_to_image_generation = 'image-to-image-generation'


 class NLPTasks(object):
diff --git a/tests/pipelines/test_image2image_generation.py b/tests/pipelines/test_image2image_generation.py
new file mode 100644
index 00000000..dceb61c6
--- /dev/null
+++ b/tests/pipelines/test_image2image_generation.py
@@ -0,0 +1,48 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import os.path as osp
+import shutil
+import unittest
+
+from torchvision.utils import save_image
+
+from modelscope.fileio import File
+from modelscope.msdatasets import MsDataset
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class Image2ImageGenerationTest(unittest.TestCase):
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_modelhub(self):
+        r"""We provide two generation modes, i.e., Similar Image Generation and Interpolation.
+        Pass one of the following inputs to select a mode:
+        1. Similar Image Generation Mode: a single image (path or URL).
+        2. Interpolation Mode: a tuple of two images (content, style).
+        """
+        img2img_gen_pipeline = pipeline(
+            Tasks.image_to_image_generation,
+            model='damo/cv_latent_diffusion_image2image_generate')
+
+        # Similar Image Generation mode
+        result1 = img2img_gen_pipeline('data/test/images/img2img_input.jpg')
+        # Interpolation mode
+        result2 = img2img_gen_pipeline(('data/test/images/img2img_input.jpg',
+                                        'data/test/images/img2img_style.jpg'))
+        save_image(
+            result1['output_img'].clamp(-1, 1),
+            'result1.jpg',
+            range=(-1, 1),
+            normalize=True,
+            nrow=4)
+        save_image(
+            result2['output_img'].clamp(-1, 1),
+            'result2.jpg',
+            range=(-1, 1),
+            normalize=True,
+            nrow=4)
+
+
+if __name__ == '__main__':
+    unittest.main()