Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9564875master
| @@ -0,0 +1,3 @@ | |||||
| version https://git-lfs.github.com/spec/v1 | |||||
| oid sha256:ef06465535002fd565f3e50d16772bdcb8e47f474fb7d7c318510fff49ab1090 | |||||
| size 212790 | |||||
| @@ -91,6 +91,7 @@ class Pipelines(object): | |||||
| image2image_translation = 'image-to-image-translation' | image2image_translation = 'image-to-image-translation' | ||||
| live_category = 'live-category' | live_category = 'live-category' | ||||
| video_category = 'video-category' | video_category = 'video-category' | ||||
| image_to_image_generation = 'image-to-image-generation' | |||||
| # nlp tasks | # nlp tasks | ||||
| sentence_similarity = 'sentence-similarity' | sentence_similarity = 'sentence-similarity' | ||||
| @@ -3,5 +3,6 @@ from . import (action_recognition, animal_recognition, cartoon, | |||||
| cmdssl_video_embedding, face_detection, face_generation, | cmdssl_video_embedding, face_detection, face_generation, | ||||
| image_classification, image_color_enhance, image_colorization, | image_classification, image_color_enhance, image_colorization, | ||||
| image_denoise, image_instance_segmentation, | image_denoise, image_instance_segmentation, | ||||
| image_to_image_translation, object_detection, | |||||
| product_retrieval_embedding, super_resolution, virual_tryon) | |||||
| image_to_image_generation, image_to_image_translation, | |||||
| object_detection, product_retrieval_embedding, super_resolution, | |||||
| virual_tryon) | |||||
| @@ -0,0 +1,2 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from . import data, models, ops | |||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .transforms import PadToSquare | |||||
| else: | |||||
| _import_structure = { | |||||
| 'transforms': ['PadToSquare'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| # from .transforms import * # noqa F403 | |||||
| @@ -0,0 +1,121 @@ | |||||
| import math | |||||
| import random | |||||
| import torchvision.transforms.functional as TF | |||||
| from PIL import Image, ImageFilter | |||||
| __all__ = [ | |||||
| 'Identity', 'PadToSquare', 'RandomScale', 'RandomRotate', | |||||
| 'RandomGaussianBlur', 'RandomCrop' | |||||
| ] | |||||
| class Identity(object): | |||||
| def __call__(self, *args): | |||||
| if len(args) == 0: | |||||
| return None | |||||
| elif len(args) == 1: | |||||
| return args[0] | |||||
| else: | |||||
| return args | |||||
| class PadToSquare(object): | |||||
| def __init__(self, fill=(255, 255, 255)): | |||||
| self.fill = fill | |||||
| def __call__(self, img): | |||||
| w, h = img.size | |||||
| if w != h: | |||||
| if w > h: | |||||
| t = (w - h) // 2 | |||||
| b = w - h - t | |||||
| padding = (0, t, 0, b) | |||||
| else: | |||||
| left = (h - w) // 2 | |||||
| right = h - w - l | |||||
| padding = (left, 0, right, 0) | |||||
| img = TF.pad(img, padding, fill=self.fill) | |||||
| return img | |||||
| class RandomScale(object): | |||||
| def __init__(self, | |||||
| min_scale=0.5, | |||||
| max_scale=2.0, | |||||
| min_ratio=0.8, | |||||
| max_ratio=1.25): | |||||
| self.min_scale = min_scale | |||||
| self.max_scale = max_scale | |||||
| self.min_ratio = min_ratio | |||||
| self.max_ratio = max_ratio | |||||
| def __call__(self, img): | |||||
| w, h = img.size | |||||
| scale = 2**random.uniform( | |||||
| math.log2(self.min_scale), math.log2(self.max_scale)) | |||||
| ratio = 2**random.uniform( | |||||
| math.log2(self.min_ratio), math.log2(self.max_ratio)) | |||||
| ow = int(w * scale * math.sqrt(ratio)) | |||||
| oh = int(h * scale / math.sqrt(ratio)) | |||||
| img = img.resize((ow, oh), Image.BILINEAR) | |||||
| return img | |||||
| class RandomRotate(object): | |||||
| def __init__(self, | |||||
| min_angle=-10.0, | |||||
| max_angle=10.0, | |||||
| padding=(255, 255, 255), | |||||
| p=0.5): | |||||
| self.min_angle = min_angle | |||||
| self.max_angle = max_angle | |||||
| self.padding = padding | |||||
| self.p = p | |||||
| def __call__(self, img): | |||||
| if random.random() < self.p: | |||||
| angle = random.uniform(self.min_angle, self.max_angle) | |||||
| img = img.rotate(angle, Image.BILINEAR, fillcolor=self.padding) | |||||
| return img | |||||
| class RandomGaussianBlur(object): | |||||
| def __init__(self, radius=5, p=0.5): | |||||
| self.radius = radius | |||||
| self.p = p | |||||
| def __call__(self, img): | |||||
| if random.random() < self.p: | |||||
| img = img.filter(ImageFilter.GaussianBlur(radius=self.radius)) | |||||
| return img | |||||
| class RandomCrop(object): | |||||
| def __init__(self, size, padding=(255, 255, 255)): | |||||
| self.size = size | |||||
| self.padding = padding | |||||
| def __call__(self, img): | |||||
| # pad | |||||
| w, h = img.size | |||||
| pad_w = max(0, self.size - w) | |||||
| pad_h = max(0, self.size - h) | |||||
| if pad_w > 0 or pad_h > 0: | |||||
| half_w = pad_w // 2 | |||||
| half_h = pad_h // 2 | |||||
| pad = (half_w, half_h, pad_w - half_w, pad_h - half_h) | |||||
| img = TF.pad(img, pad, fill=self.padding) | |||||
| # crop | |||||
| w, h = img.size | |||||
| x1 = random.randint(0, w - self.size) | |||||
| y1 = random.randint(0, h - self.size) | |||||
| img = img.crop((x1, y1, x1 + self.size, y1 + self.size)) | |||||
| return img | |||||
| @@ -0,0 +1,322 @@ | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['UNet'] | |||||
| def sinusoidal_embedding(timesteps, dim): | |||||
| # check input | |||||
| half = dim // 2 | |||||
| timesteps = timesteps.float() | |||||
| # compute sinusoidal embedding | |||||
| sinusoid = torch.outer( | |||||
| timesteps, torch.pow(10000, | |||||
| -torch.arange(half).to(timesteps).div(half))) | |||||
| x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) | |||||
| if dim % 2 != 0: | |||||
| x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) | |||||
| return x | |||||
| class Resample(nn.Module): | |||||
| def __init__(self, scale_factor=1.0): | |||||
| assert scale_factor in [0.5, 1.0, 2.0] | |||||
| super(Resample, self).__init__() | |||||
| self.scale_factor = scale_factor | |||||
| def forward(self, x): | |||||
| if self.scale_factor == 2.0: | |||||
| x = F.interpolate(x, scale_factor=2, mode='nearest') | |||||
| elif self.scale_factor == 0.5: | |||||
| x = F.avg_pool2d(x, kernel_size=2, stride=2) | |||||
| return x | |||||
| class ResidualBlock(nn.Module): | |||||
| def __init__(self, in_dim, embed_dim, out_dim, dropout=0.0): | |||||
| super(ResidualBlock, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.embed_dim = embed_dim | |||||
| self.out_dim = out_dim | |||||
| # layers | |||||
| self.layer1 = nn.Sequential( | |||||
| nn.GroupNorm(32, in_dim), nn.SiLU(), | |||||
| nn.Conv2d(in_dim, out_dim, 3, padding=1)) | |||||
| self.embedding = nn.Sequential(nn.SiLU(), | |||||
| nn.Linear(embed_dim, out_dim)) | |||||
| self.layer2 = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), | |||||
| nn.Conv2d(out_dim, out_dim, 3, padding=1)) | |||||
| self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( | |||||
| in_dim, out_dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.layer2[-1].weight) | |||||
| def forward(self, x, y): | |||||
| identity = x | |||||
| x = self.layer1(x) | |||||
| x = x + self.embedding(y).unsqueeze(-1).unsqueeze(-1) | |||||
| x = self.layer2(x) | |||||
| x = x + self.shortcut(identity) | |||||
| return x | |||||
| class MultiHeadAttention(nn.Module): | |||||
| def __init__(self, dim, context_dim=None, num_heads=8, dropout=0.0): | |||||
| assert dim % num_heads == 0 | |||||
| assert context_dim is None or context_dim % num_heads == 0 | |||||
| context_dim = context_dim or dim | |||||
| super(MultiHeadAttention, self).__init__() | |||||
| self.dim = dim | |||||
| self.context_dim = context_dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = dim // num_heads | |||||
| self.scale = math.pow(self.head_dim, -0.25) | |||||
| # layers | |||||
| self.q = nn.Linear(dim, dim, bias=False) | |||||
| self.k = nn.Linear(context_dim, dim, bias=False) | |||||
| self.v = nn.Linear(context_dim, dim, bias=False) | |||||
| self.o = nn.Linear(dim, dim) | |||||
| self.dropout = nn.Dropout(dropout) | |||||
| def forward(self, x, context=None): | |||||
| # check inputs | |||||
| context = x if context is None else context | |||||
| b, n, c = x.size(0), self.num_heads, self.head_dim | |||||
| # compute query, key, value | |||||
| q = self.q(x).view(b, -1, n, c) | |||||
| k = self.k(context).view(b, -1, n, c) | |||||
| v = self.v(context).view(b, -1, n, c) | |||||
| # compute attention | |||||
| attn = torch.einsum('binc,bjnc->bnij', q * self.scale, k * self.scale) | |||||
| attn = F.softmax(attn, dim=-1) | |||||
| attn = self.dropout(attn) | |||||
| # gather context | |||||
| x = torch.einsum('bnij,bjnc->binc', attn, v) | |||||
| x = x.reshape(b, -1, n * c) | |||||
| # output | |||||
| x = self.o(x) | |||||
| x = self.dropout(x) | |||||
| return x | |||||
| class GLU(nn.Module): | |||||
| def __init__(self, in_dim, out_dim): | |||||
| super(GLU, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.out_dim = out_dim | |||||
| self.proj = nn.Linear(in_dim, out_dim * 2) | |||||
| def forward(self, x): | |||||
| x, gate = self.proj(x).chunk(2, dim=-1) | |||||
| return x * F.gelu(gate) | |||||
| class TransformerBlock(nn.Module): | |||||
| def __init__(self, dim, context_dim, num_heads, dropout=0.0): | |||||
| super(TransformerBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.context_dim = context_dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = dim // num_heads | |||||
| # input | |||||
| self.norm1 = nn.GroupNorm(32, dim, eps=1e-6, affine=True) | |||||
| self.conv1 = nn.Conv2d(dim, dim, 1) | |||||
| # self attention | |||||
| self.norm2 = nn.LayerNorm(dim) | |||||
| self.self_attn = MultiHeadAttention(dim, None, num_heads, dropout) | |||||
| # cross attention | |||||
| self.norm3 = nn.LayerNorm(dim) | |||||
| self.cross_attn = MultiHeadAttention(dim, context_dim, num_heads, | |||||
| dropout) | |||||
| # ffn | |||||
| self.norm4 = nn.LayerNorm(dim) | |||||
| self.ffn = nn.Sequential( | |||||
| GLU(dim, dim * 4), nn.Dropout(dropout), nn.Linear(dim * 4, dim)) | |||||
| # output | |||||
| self.conv2 = nn.Conv2d(dim, dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.conv2.weight) | |||||
| def forward(self, x, context): | |||||
| b, c, h, w = x.size() | |||||
| identity = x | |||||
| # input | |||||
| x = self.norm1(x) | |||||
| x = self.conv1(x).view(b, c, -1).transpose(1, 2) | |||||
| # attention | |||||
| x = x + self.self_attn(self.norm2(x)) | |||||
| x = x + self.cross_attn(self.norm3(x), context) | |||||
| x = x + self.ffn(self.norm4(x)) | |||||
| # output | |||||
| x = x.transpose(1, 2).view(b, c, h, w) | |||||
| x = self.conv2(x) | |||||
| return x + identity | |||||
| class UNet(nn.Module): | |||||
| def __init__(self, | |||||
| resolution=64, | |||||
| in_dim=3, | |||||
| dim=192, | |||||
| label_dim=512, | |||||
| context_dim=512, | |||||
| out_dim=3, | |||||
| dim_mult=[1, 2, 3, 5], | |||||
| num_heads=1, | |||||
| head_dim=None, | |||||
| num_res_blocks=2, | |||||
| attn_scales=[1 / 2, 1 / 4, 1 / 8], | |||||
| dropout=0.0): | |||||
| embed_dim = dim * 4 | |||||
| super(UNet, self).__init__() | |||||
| self.resolution = resolution | |||||
| self.in_dim = in_dim | |||||
| self.dim = dim | |||||
| self.context_dim = context_dim | |||||
| self.out_dim = out_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = head_dim | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| # params | |||||
| enc_dims = [dim * u for u in [1] + dim_mult] | |||||
| dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||||
| shortcut_dims = [] | |||||
| scale = 1.0 | |||||
| # embeddings | |||||
| self.time_embedding = nn.Sequential( | |||||
| nn.Linear(dim, embed_dim), nn.SiLU(), | |||||
| nn.Linear(embed_dim, embed_dim)) | |||||
| self.clip_embedding = nn.Sequential( | |||||
| nn.Linear(label_dim, context_dim), nn.SiLU(), | |||||
| nn.Linear(context_dim, context_dim)) | |||||
| # encoder | |||||
| self.encoder = nn.ModuleList( | |||||
| [nn.Conv2d(self.in_dim, dim, 3, padding=1)]) | |||||
| shortcut_dims.append(dim) | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): | |||||
| for j in range(num_res_blocks): | |||||
| # residual (+attention) blocks | |||||
| block = nn.ModuleList( | |||||
| [ResidualBlock(in_dim, embed_dim, out_dim, dropout)]) | |||||
| if scale in attn_scales: | |||||
| block.append( | |||||
| TransformerBlock(out_dim, context_dim, num_heads)) | |||||
| in_dim = out_dim | |||||
| self.encoder.append(block) | |||||
| shortcut_dims.append(out_dim) | |||||
| # downsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks - 1: | |||||
| self.encoder.append( | |||||
| nn.Conv2d(out_dim, out_dim, 3, stride=2, padding=1)) | |||||
| shortcut_dims.append(out_dim) | |||||
| scale /= 2.0 | |||||
| # middle | |||||
| self.middle = nn.ModuleList([ | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, dropout), | |||||
| TransformerBlock(out_dim, context_dim, num_heads), | |||||
| ResidualBlock(out_dim, embed_dim, out_dim, dropout) | |||||
| ]) | |||||
| # decoder | |||||
| self.decoder = nn.ModuleList() | |||||
| for i, (in_dim, | |||||
| out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): | |||||
| for j in range(num_res_blocks + 1): | |||||
| # residual (+attention) blocks | |||||
| block = nn.ModuleList([ | |||||
| ResidualBlock(in_dim + shortcut_dims.pop(), embed_dim, | |||||
| out_dim, dropout) | |||||
| ]) | |||||
| if scale in attn_scales: | |||||
| block.append( | |||||
| TransformerBlock(out_dim, context_dim, num_heads, | |||||
| dropout)) | |||||
| in_dim = out_dim | |||||
| # upsample | |||||
| if i != len(dim_mult) - 1 and j == num_res_blocks: | |||||
| block.append( | |||||
| nn.Sequential( | |||||
| Resample(scale_factor=2.0), | |||||
| nn.Conv2d(out_dim, out_dim, 3, padding=1))) | |||||
| scale *= 2.0 | |||||
| self.decoder.append(block) | |||||
| # head | |||||
| self.head = nn.Sequential( | |||||
| nn.GroupNorm(32, out_dim), nn.SiLU(), | |||||
| nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.head[-1].weight) | |||||
| def forward(self, x, t, y): | |||||
| # embeddings | |||||
| t = self.time_embedding(sinusoidal_embedding(t, self.dim)) | |||||
| y = self.clip_embedding(y) | |||||
| # encoder | |||||
| xs = [] | |||||
| for block in self.encoder: | |||||
| x = self._forward_single(block, x, t, y) | |||||
| xs.append(x) | |||||
| # middle | |||||
| for block in self.middle: | |||||
| x = self._forward_single(block, x, t, y) | |||||
| # decoder | |||||
| for block in self.decoder: | |||||
| x = torch.cat([x, xs.pop()], dim=1) | |||||
| x = self._forward_single(block, x, t, y) | |||||
| # head | |||||
| x = self.head(x) | |||||
| return x | |||||
| def _forward_single(self, module, x, t, y): | |||||
| if isinstance(module, ResidualBlock): | |||||
| x = module(x, t) | |||||
| elif isinstance(module, TransformerBlock): | |||||
| x = module(x, y) | |||||
| elif isinstance(module, nn.ModuleList): | |||||
| for block in module: | |||||
| x = self._forward_single(block, x, t, y) | |||||
| else: | |||||
| x = module(x) | |||||
| return x | |||||
| @@ -0,0 +1,24 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .autoencoder import VQAutoencoder | |||||
| from .clip import VisionTransformer | |||||
| else: | |||||
| _import_structure = { | |||||
| 'autoencoder': ['VQAutoencoder'], | |||||
| 'clip': ['VisionTransformer'] | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,412 @@ | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| __all__ = ['VQAutoencoder', 'KLAutoencoder', 'PatchDiscriminator'] | |||||
| def group_norm(dim): | |||||
| return nn.GroupNorm(32, dim, eps=1e-6, affine=True) | |||||
| class Resample(nn.Module): | |||||
| def __init__(self, dim, scale_factor): | |||||
| super(Resample, self).__init__() | |||||
| self.dim = dim | |||||
| self.scale_factor = scale_factor | |||||
| # layers | |||||
| if scale_factor == 2.0: | |||||
| self.resample = nn.Sequential( | |||||
| nn.Upsample(scale_factor=scale_factor, mode='nearest'), | |||||
| nn.Conv2d(dim, dim, 3, padding=1)) | |||||
| elif scale_factor == 0.5: | |||||
| self.resample = nn.Sequential( | |||||
| nn.ZeroPad2d((0, 1, 0, 1)), | |||||
| nn.Conv2d(dim, dim, 3, stride=2, padding=0)) | |||||
| else: | |||||
| self.resample = nn.Identity() | |||||
| def forward(self, x): | |||||
| return self.resample(x) | |||||
| class ResidualBlock(nn.Module): | |||||
| def __init__(self, in_dim, out_dim, dropout=0.0): | |||||
| super(ResidualBlock, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.out_dim = out_dim | |||||
| # layers | |||||
| self.residual = nn.Sequential( | |||||
| group_norm(in_dim), nn.SiLU(), | |||||
| nn.Conv2d(in_dim, out_dim, 3, padding=1), group_norm(out_dim), | |||||
| nn.SiLU(), nn.Dropout(dropout), | |||||
| nn.Conv2d(out_dim, out_dim, 3, padding=1)) | |||||
| self.shortcut = nn.Conv2d(in_dim, out_dim, | |||||
| 1) if in_dim != out_dim else nn.Identity() | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.residual[-1].weight) | |||||
| def forward(self, x): | |||||
| return self.residual(x) + self.shortcut(x) | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim): | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.scale = math.pow(dim, -0.25) | |||||
| # layers | |||||
| self.norm = group_norm(dim) | |||||
| self.to_qkv = nn.Conv2d(dim, dim * 3, 1) | |||||
| self.proj = nn.Conv2d(dim, dim, 1) | |||||
| # zero out the last layer params | |||||
| nn.init.zeros_(self.proj.weight) | |||||
| def forward(self, x): | |||||
| identity = x | |||||
| b, c, h, w = x.size() | |||||
| # compute query, key, value | |||||
| x = self.norm(x) | |||||
| q, k, v = self.to_qkv(x).view(b, c * 3, -1).chunk(3, dim=1) | |||||
| # compute attention | |||||
| attn = torch.einsum('bci,bcj->bij', q * self.scale, k * self.scale) | |||||
| attn = F.softmax(attn, dim=-1) | |||||
| # gather context | |||||
| x = torch.einsum('bij,bcj->bci', attn, v) | |||||
| x = x.reshape(b, c, h, w) | |||||
| # output | |||||
| x = self.proj(x) | |||||
| return x + identity | |||||
| class Encoder(nn.Module): | |||||
| def __init__(self, | |||||
| dim=128, | |||||
| z_dim=3, | |||||
| dim_mult=[1, 2, 4], | |||||
| num_res_blocks=2, | |||||
| attn_scales=[], | |||||
| dropout=0.0): | |||||
| super(Encoder, self).__init__() | |||||
| self.dim = dim | |||||
| self.z_dim = z_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| # params | |||||
| dims = [dim * u for u in [1] + dim_mult] | |||||
| scale = 1.0 | |||||
| # init block | |||||
| self.conv1 = nn.Conv2d(3, dims[0], 3, padding=1) | |||||
| # downsample blocks | |||||
| downsamples = [] | |||||
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |||||
| # residual (+attention) blocks | |||||
| for _ in range(num_res_blocks): | |||||
| downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |||||
| if scale in attn_scales: | |||||
| downsamples.append(AttentionBlock(out_dim)) | |||||
| in_dim = out_dim | |||||
| # downsample block | |||||
| if i != len(dim_mult) - 1: | |||||
| downsamples.append(Resample(out_dim, scale_factor=0.5)) | |||||
| scale /= 2.0 | |||||
| self.downsamples = nn.Sequential(*downsamples) | |||||
| # middle blocks | |||||
| self.middle = nn.Sequential( | |||||
| ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), | |||||
| ResidualBlock(out_dim, out_dim, dropout)) | |||||
| # output blocks | |||||
| self.head = nn.Sequential( | |||||
| group_norm(out_dim), nn.SiLU(), | |||||
| nn.Conv2d(out_dim, z_dim, 3, padding=1)) | |||||
| def forward(self, x): | |||||
| x = self.conv1(x) | |||||
| x = self.downsamples(x) | |||||
| x = self.middle(x) | |||||
| x = self.head(x) | |||||
| return x | |||||
| class Decoder(nn.Module): | |||||
| def __init__(self, | |||||
| dim=128, | |||||
| z_dim=3, | |||||
| dim_mult=[1, 2, 4], | |||||
| num_res_blocks=2, | |||||
| attn_scales=[], | |||||
| dropout=0.0): | |||||
| super(Decoder, self).__init__() | |||||
| self.dim = dim | |||||
| self.z_dim = z_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| # params | |||||
| dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] | |||||
| scale = 1.0 / 2**(len(dim_mult) - 2) | |||||
| # init block | |||||
| self.conv1 = nn.Conv2d(z_dim, dims[0], 3, padding=1) | |||||
| # middle blocks | |||||
| self.middle = nn.Sequential( | |||||
| ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), | |||||
| ResidualBlock(dims[0], dims[0], dropout)) | |||||
| # upsample blocks | |||||
| upsamples = [] | |||||
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |||||
| # residual (+attention) blocks | |||||
| for _ in range(num_res_blocks + 1): | |||||
| upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) | |||||
| if scale in attn_scales: | |||||
| upsamples.append(AttentionBlock(out_dim)) | |||||
| in_dim = out_dim | |||||
| # upsample block | |||||
| if i != len(dim_mult) - 1: | |||||
| upsamples.append(Resample(out_dim, scale_factor=2.0)) | |||||
| scale *= 2.0 | |||||
| self.upsamples = nn.Sequential(*upsamples) | |||||
| # output blocks | |||||
| self.head = nn.Sequential( | |||||
| group_norm(out_dim), nn.SiLU(), | |||||
| nn.Conv2d(out_dim, 3, 3, padding=1)) | |||||
| def forward(self, x): | |||||
| x = self.conv1(x) | |||||
| x = self.middle(x) | |||||
| x = self.upsamples(x) | |||||
| x = self.head(x) | |||||
| return x | |||||
| class VectorQuantizer(nn.Module): | |||||
| def __init__(self, codebook_size=8192, z_dim=3, beta=0.25): | |||||
| super(VectorQuantizer, self).__init__() | |||||
| self.codebook_size = codebook_size | |||||
| self.z_dim = z_dim | |||||
| self.beta = beta | |||||
| # init codebook | |||||
| eps = math.sqrt(1.0 / codebook_size) | |||||
| self.codebook = nn.Parameter( | |||||
| torch.empty(codebook_size, z_dim).uniform_(-eps, eps)) | |||||
| def forward(self, z): | |||||
| # preprocess | |||||
| b, c, h, w = z.size() | |||||
| flatten = z.permute(0, 2, 3, 1).reshape(-1, c) | |||||
| # quantization | |||||
| with torch.no_grad(): | |||||
| tokens = torch.cdist(flatten, self.codebook).argmin(dim=1) | |||||
| quantized = F.embedding(tokens, | |||||
| self.codebook).view(b, h, w, | |||||
| c).permute(0, 3, 1, 2) | |||||
| # compute loss | |||||
| codebook_loss = F.mse_loss(quantized, z.detach()) | |||||
| commitment_loss = F.mse_loss(quantized.detach(), z) | |||||
| loss = codebook_loss + self.beta * commitment_loss | |||||
| # perplexity | |||||
| counts = F.one_hot(tokens, self.codebook_size).sum(dim=0).to(z.dtype) | |||||
| # dist.all_reduce(counts) | |||||
| p = counts / counts.sum() | |||||
| perplexity = torch.exp(-torch.sum(p * torch.log(p + 1e-10))) | |||||
| # postprocess | |||||
| tokens = tokens.view(b, h, w) | |||||
| quantized = z + (quantized - z).detach() | |||||
| return quantized, tokens, loss, perplexity | |||||
| class VQAutoencoder(nn.Module): | |||||
| def __init__(self, | |||||
| dim=128, | |||||
| z_dim=3, | |||||
| dim_mult=[1, 2, 4], | |||||
| num_res_blocks=2, | |||||
| attn_scales=[], | |||||
| dropout=0.0, | |||||
| codebook_size=8192, | |||||
| beta=0.25): | |||||
| super(VQAutoencoder, self).__init__() | |||||
| self.dim = dim | |||||
| self.z_dim = z_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| self.codebook_size = codebook_size | |||||
| self.beta = beta | |||||
| # blocks | |||||
| self.encoder = Encoder(dim, z_dim, dim_mult, num_res_blocks, | |||||
| attn_scales, dropout) | |||||
| self.conv1 = nn.Conv2d(z_dim, z_dim, 1) | |||||
| self.quantizer = VectorQuantizer(codebook_size, z_dim, beta) | |||||
| self.conv2 = nn.Conv2d(z_dim, z_dim, 1) | |||||
| self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, | |||||
| attn_scales, dropout) | |||||
| def forward(self, x): | |||||
| z = self.encoder(x) | |||||
| z = self.conv1(z) | |||||
| z, tokens, loss, perplexity = self.quantizer(z) | |||||
| z = self.conv2(z) | |||||
| x = self.decoder(z) | |||||
| return x, tokens, loss, perplexity | |||||
| def encode(self, imgs): | |||||
| z = self.encoder(imgs) | |||||
| z = self.conv1(z) | |||||
| return z | |||||
| def decode(self, z): | |||||
| r"""Absort the quantizer in the decoder. | |||||
| """ | |||||
| z = self.quantizer(z)[0] | |||||
| z = self.conv2(z) | |||||
| imgs = self.decoder(z) | |||||
| return imgs | |||||
| @torch.no_grad() | |||||
| def encode_to_tokens(self, imgs): | |||||
| # preprocess | |||||
| z = self.encoder(imgs) | |||||
| z = self.conv1(z) | |||||
| # quantization | |||||
| b, c, h, w = z.size() | |||||
| flatten = z.permute(0, 2, 3, 1).reshape(-1, c) | |||||
| tokens = torch.cdist(flatten, self.quantizer.codebook).argmin(dim=1) | |||||
| return tokens.view(b, -1) | |||||
| @torch.no_grad() | |||||
| def decode_from_tokens(self, tokens): | |||||
| # dequantization | |||||
| z = F.embedding(tokens, self.quantizer.codebook) | |||||
| # postprocess | |||||
| b, l, c = z.size() | |||||
| h = w = int(math.sqrt(l)) | |||||
| z = z.view(b, h, w, c).permute(0, 3, 1, 2) | |||||
| z = self.conv2(z) | |||||
| imgs = self.decoder(z) | |||||
| return imgs | |||||
| class KLAutoencoder(nn.Module): | |||||
| def __init__(self, | |||||
| dim=128, | |||||
| z_dim=4, | |||||
| dim_mult=[1, 2, 4, 4], | |||||
| num_res_blocks=2, | |||||
| attn_scales=[], | |||||
| dropout=0.0): | |||||
| super(KLAutoencoder, self).__init__() | |||||
| self.dim = dim | |||||
| self.z_dim = z_dim | |||||
| self.dim_mult = dim_mult | |||||
| self.num_res_blocks = num_res_blocks | |||||
| self.attn_scales = attn_scales | |||||
| # blocks | |||||
| self.encoder = Encoder(dim, z_dim * 2, dim_mult, num_res_blocks, | |||||
| attn_scales, dropout) | |||||
| self.conv1 = nn.Conv2d(z_dim * 2, z_dim * 2, 1) | |||||
| self.conv2 = nn.Conv2d(z_dim, z_dim, 1) | |||||
| self.decoder = Decoder(dim, z_dim, dim_mult, num_res_blocks, | |||||
| attn_scales, dropout) | |||||
| def forward(self, x): | |||||
| mu, log_var = self.encode(x) | |||||
| z = self.reparameterize(mu, log_var) | |||||
| x = self.decode(z) | |||||
| return x, mu, log_var | |||||
| def encode(self, x): | |||||
| x = self.encoder(x) | |||||
| mu, log_var = self.conv1(x).chunk(2, dim=1) | |||||
| return mu, log_var | |||||
| def decode(self, z): | |||||
| x = self.conv2(z) | |||||
| x = self.decoder(x) | |||||
| return x | |||||
| def reparameterize(self, mu, log_var): | |||||
| std = torch.exp(0.5 * log_var) | |||||
| eps = torch.randn_like(std) | |||||
| return eps * std + mu | |||||
| class PatchDiscriminator(nn.Module): | |||||
| def __init__(self, in_dim=3, dim=64, num_layers=3): | |||||
| super(PatchDiscriminator, self).__init__() | |||||
| self.in_dim = in_dim | |||||
| self.dim = dim | |||||
| self.num_layers = num_layers | |||||
| # params | |||||
| dims = [dim * min(8, 2**u) for u in range(num_layers + 1)] | |||||
| # layers | |||||
| layers = [ | |||||
| nn.Conv2d(in_dim, dim, 4, stride=2, padding=1), | |||||
| nn.LeakyReLU(0.2) | |||||
| ] | |||||
| for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): | |||||
| stride = 1 if i == num_layers - 1 else 2 | |||||
| layers += [ | |||||
| nn.Conv2d( | |||||
| in_dim, out_dim, 4, stride=stride, padding=1, bias=False), | |||||
| nn.BatchNorm2d(out_dim), | |||||
| nn.LeakyReLU(0.2) | |||||
| ] | |||||
| layers += [nn.Conv2d(out_dim, 1, 4, stride=1, padding=1)] | |||||
| self.layers = nn.Sequential(*layers) | |||||
| # initialize weights | |||||
| self.apply(self.init_weights) | |||||
| def forward(self, x): | |||||
| return self.layers(x) | |||||
| def init_weights(self, m): | |||||
| if isinstance(m, nn.Conv2d): | |||||
| nn.init.normal_(m.weight, 0.0, 0.02) | |||||
| elif isinstance(m, nn.BatchNorm2d): | |||||
| nn.init.normal_(m.weight, 1.0, 0.02) | |||||
| nn.init.zeros_(m.bias) | |||||
| @@ -0,0 +1,418 @@ | |||||
| import math | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import modelscope.models.cv.image_to_image_translation.ops as ops # for using differentiable all_gather | |||||
| __all__ = [ | |||||
| 'CLIP', 'clip_vit_b_32', 'clip_vit_b_16', 'clip_vit_l_14', | |||||
| 'clip_vit_l_14_336px', 'clip_vit_h_16' | |||||
| ] | |||||
| def to_fp16(m): | |||||
| if isinstance(m, (nn.Linear, nn.Conv2d)): | |||||
| m.weight.data = m.weight.data.half() | |||||
| if m.bias is not None: | |||||
| m.bias.data = m.bias.data.half() | |||||
| elif hasattr(m, 'head'): | |||||
| p = getattr(m, 'head') | |||||
| p.data = p.data.half() | |||||
| class QuickGELU(nn.Module): | |||||
| def forward(self, x): | |||||
| return x * torch.sigmoid(1.702 * x) | |||||
| class LayerNorm(nn.LayerNorm): | |||||
| r"""Subclass of nn.LayerNorm to handle fp16. | |||||
| """ | |||||
| def forward(self, x): | |||||
| return super(LayerNorm, self).forward(x.float()).type_as(x) | |||||
| class SelfAttention(nn.Module): | |||||
| def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): | |||||
| assert dim % num_heads == 0 | |||||
| super(SelfAttention, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| self.head_dim = dim // num_heads | |||||
| self.scale = 1.0 / math.sqrt(self.head_dim) | |||||
| # layers | |||||
| self.to_qkv = nn.Linear(dim, dim * 3) | |||||
| self.attn_dropout = nn.Dropout(attn_dropout) | |||||
| self.proj = nn.Linear(dim, dim) | |||||
| self.proj_dropout = nn.Dropout(proj_dropout) | |||||
| def forward(self, x, mask=None): | |||||
| r"""x: [B, L, C]. | |||||
| mask: [*, L, L]. | |||||
| """ | |||||
| b, l, _, n = *x.size(), self.num_heads | |||||
| # compute query, key, and value | |||||
| q, k, v = self.to_qkv(x.transpose(0, 1)).chunk(3, dim=-1) | |||||
| q = q.reshape(l, b * n, -1).transpose(0, 1) | |||||
| k = k.reshape(l, b * n, -1).transpose(0, 1) | |||||
| v = v.reshape(l, b * n, -1).transpose(0, 1) | |||||
| # compute attention | |||||
| attn = self.scale * torch.bmm(q, k.transpose(1, 2)) | |||||
| if mask is not None: | |||||
| attn = attn.masked_fill(mask[:, :l, :l] == 0, float('-inf')) | |||||
| attn = F.softmax(attn.float(), dim=-1).type_as(attn) | |||||
| attn = self.attn_dropout(attn) | |||||
| # gather context | |||||
| x = torch.bmm(attn, v) | |||||
| x = x.view(b, n, l, -1).transpose(1, 2).reshape(b, l, -1) | |||||
| # output | |||||
| x = self.proj(x) | |||||
| x = self.proj_dropout(x) | |||||
| return x | |||||
| class AttentionBlock(nn.Module): | |||||
| def __init__(self, dim, num_heads, attn_dropout=0.0, proj_dropout=0.0): | |||||
| super(AttentionBlock, self).__init__() | |||||
| self.dim = dim | |||||
| self.num_heads = num_heads | |||||
| # layers | |||||
| self.norm1 = LayerNorm(dim) | |||||
| self.attn = SelfAttention(dim, num_heads, attn_dropout, proj_dropout) | |||||
| self.norm2 = LayerNorm(dim) | |||||
| self.mlp = nn.Sequential( | |||||
| nn.Linear(dim, dim * 4), QuickGELU(), nn.Linear(dim * 4, dim), | |||||
| nn.Dropout(proj_dropout)) | |||||
| def forward(self, x, mask=None): | |||||
| x = x + self.attn(self.norm1(x), mask) | |||||
| x = x + self.mlp(self.norm2(x)) | |||||
| return x | |||||
| class VisionTransformer(nn.Module): | |||||
| def __init__(self, | |||||
| image_size=224, | |||||
| patch_size=16, | |||||
| dim=768, | |||||
| out_dim=512, | |||||
| num_heads=12, | |||||
| num_layers=12, | |||||
| attn_dropout=0.0, | |||||
| proj_dropout=0.0, | |||||
| embedding_dropout=0.0): | |||||
| assert image_size % patch_size == 0 | |||||
| super(VisionTransformer, self).__init__() | |||||
| self.image_size = image_size | |||||
| self.patch_size = patch_size | |||||
| self.dim = dim | |||||
| self.out_dim = out_dim | |||||
| self.num_heads = num_heads | |||||
| self.num_layers = num_layers | |||||
| self.num_patches = (image_size // patch_size)**2 | |||||
| # embeddings | |||||
| gain = 1.0 / math.sqrt(dim) | |||||
| self.patch_embedding = nn.Conv2d( | |||||
| 3, dim, kernel_size=patch_size, stride=patch_size, bias=False) | |||||
| self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim)) | |||||
| self.pos_embedding = nn.Parameter( | |||||
| gain * torch.randn(1, self.num_patches + 1, dim)) | |||||
| self.dropout = nn.Dropout(embedding_dropout) | |||||
| # transformer | |||||
| self.pre_norm = LayerNorm(dim) | |||||
| self.transformer = nn.Sequential(*[ | |||||
| AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) | |||||
| for _ in range(num_layers) | |||||
| ]) | |||||
| self.post_norm = LayerNorm(dim) | |||||
| # head | |||||
| self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) | |||||
| def forward(self, x): | |||||
| b, dtype = x.size(0), self.head.dtype | |||||
| x = x.type(dtype) | |||||
| # patch-embedding | |||||
| x = self.patch_embedding(x).flatten(2).permute(0, 2, 1) # [b, n, c] | |||||
| x = torch.cat([self.cls_embedding.repeat(b, 1, 1).type(dtype), x], | |||||
| dim=1) | |||||
| x = self.dropout(x + self.pos_embedding.type(dtype)) | |||||
| x = self.pre_norm(x) | |||||
| # transformer | |||||
| x = self.transformer(x) | |||||
| # head | |||||
| x = self.post_norm(x) | |||||
| x = torch.mm(x[:, 0, :], self.head) | |||||
| return x | |||||
| def fp16(self): | |||||
| return self.apply(to_fp16) | |||||
| class TextTransformer(nn.Module): | |||||
| def __init__(self, | |||||
| vocab_size, | |||||
| text_len, | |||||
| dim=512, | |||||
| out_dim=512, | |||||
| num_heads=8, | |||||
| num_layers=12, | |||||
| attn_dropout=0.0, | |||||
| proj_dropout=0.0, | |||||
| embedding_dropout=0.0): | |||||
| super(TextTransformer, self).__init__() | |||||
| self.vocab_size = vocab_size | |||||
| self.text_len = text_len | |||||
| self.dim = dim | |||||
| self.out_dim = out_dim | |||||
| self.num_heads = num_heads | |||||
| self.num_layers = num_layers | |||||
| # embeddings | |||||
| self.token_embedding = nn.Embedding(vocab_size, dim) | |||||
| self.pos_embedding = nn.Parameter(0.01 * torch.randn(1, text_len, dim)) | |||||
| self.dropout = nn.Dropout(embedding_dropout) | |||||
| # transformer | |||||
| self.transformer = nn.ModuleList([ | |||||
| AttentionBlock(dim, num_heads, attn_dropout, proj_dropout) | |||||
| for _ in range(num_layers) | |||||
| ]) | |||||
| self.norm = LayerNorm(dim) | |||||
| # head | |||||
| gain = 1.0 / math.sqrt(dim) | |||||
| self.head = nn.Parameter(gain * torch.randn(dim, out_dim)) | |||||
| # causal attention mask | |||||
| self.register_buffer('attn_mask', | |||||
| torch.tril(torch.ones(1, text_len, text_len))) | |||||
| def forward(self, x): | |||||
| eot, dtype = x.argmax(dim=-1), self.head.dtype | |||||
| # embeddings | |||||
| x = self.dropout( | |||||
| self.token_embedding(x).type(dtype) | |||||
| + self.pos_embedding.type(dtype)) | |||||
| # transformer | |||||
| for block in self.transformer: | |||||
| x = block(x, self.attn_mask) | |||||
| # head | |||||
| x = self.norm(x) | |||||
| x = torch.mm(x[torch.arange(x.size(0)), eot], self.head) | |||||
| return x | |||||
| def fp16(self): | |||||
| return self.apply(to_fp16) | |||||
| class CLIP(nn.Module): | |||||
| def __init__(self, | |||||
| embed_dim=512, | |||||
| image_size=224, | |||||
| patch_size=16, | |||||
| vision_dim=768, | |||||
| vision_heads=12, | |||||
| vision_layers=12, | |||||
| vocab_size=49408, | |||||
| text_len=77, | |||||
| text_dim=512, | |||||
| text_heads=8, | |||||
| text_layers=12, | |||||
| attn_dropout=0.0, | |||||
| proj_dropout=0.0, | |||||
| embedding_dropout=0.0): | |||||
| super(CLIP, self).__init__() | |||||
| self.embed_dim = embed_dim | |||||
| self.image_size = image_size | |||||
| self.patch_size = patch_size | |||||
| self.vision_dim = vision_dim | |||||
| self.vision_heads = vision_heads | |||||
| self.vision_layers = vision_layers | |||||
| self.vocab_size = vocab_size | |||||
| self.text_len = text_len | |||||
| self.text_dim = text_dim | |||||
| self.text_heads = text_heads | |||||
| self.text_layers = text_layers | |||||
| # models | |||||
| self.visual = VisionTransformer( | |||||
| image_size=image_size, | |||||
| patch_size=patch_size, | |||||
| dim=vision_dim, | |||||
| out_dim=embed_dim, | |||||
| num_heads=vision_heads, | |||||
| num_layers=vision_layers, | |||||
| attn_dropout=attn_dropout, | |||||
| proj_dropout=proj_dropout, | |||||
| embedding_dropout=embedding_dropout) | |||||
| self.textual = TextTransformer( | |||||
| vocab_size=vocab_size, | |||||
| text_len=text_len, | |||||
| dim=text_dim, | |||||
| out_dim=embed_dim, | |||||
| num_heads=text_heads, | |||||
| num_layers=text_layers, | |||||
| attn_dropout=attn_dropout, | |||||
| proj_dropout=proj_dropout, | |||||
| embedding_dropout=embedding_dropout) | |||||
| self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([])) | |||||
| def forward(self, imgs, txt_tokens): | |||||
| r"""imgs: [B, C, H, W] of torch.float32. | |||||
| txt_tokens: [B, T] of torch.long. | |||||
| """ | |||||
| xi = self.visual(imgs) | |||||
| xt = self.textual(txt_tokens) | |||||
| # normalize features | |||||
| xi = F.normalize(xi, p=2, dim=1) | |||||
| xt = F.normalize(xt, p=2, dim=1) | |||||
| # gather features from all ranks | |||||
| full_xi = ops.diff_all_gather(xi) | |||||
| full_xt = ops.diff_all_gather(xt) | |||||
| # logits | |||||
| scale = self.log_scale.exp() | |||||
| logits_i2t = scale * torch.mm(xi, full_xt.t()) | |||||
| logits_t2i = scale * torch.mm(xt, full_xi.t()) | |||||
| # labels | |||||
| labels = torch.arange( | |||||
| len(xi) * ops.get_rank(), | |||||
| len(xi) * (ops.get_rank() + 1), | |||||
| dtype=torch.long, | |||||
| device=xi.device) | |||||
| return logits_i2t, logits_t2i, labels | |||||
| def init_weights(self): | |||||
| # embeddings | |||||
| nn.init.normal_(self.textual.token_embedding.weight, std=0.02) | |||||
| nn.init.normal_(self.visual.patch_embedding.weight, tsd=0.1) | |||||
| # attentions | |||||
| for modality in ['visual', 'textual']: | |||||
| dim = self.vision_dim if modality == 'visual' else 'textual' | |||||
| transformer = getattr(self, modality).transformer | |||||
| proj_gain = (1.0 / math.sqrt(dim)) * ( | |||||
| 1.0 / math.sqrt(2 * transformer.num_layers)) | |||||
| attn_gain = 1.0 / math.sqrt(dim) | |||||
| mlp_gain = 1.0 / math.sqrt(2.0 * dim) | |||||
| for block in transformer.layers: | |||||
| nn.init.normal_(block.attn.to_qkv.weight, std=attn_gain) | |||||
| nn.init.normal_(block.attn.proj.weight, std=proj_gain) | |||||
| nn.init.normal_(block.mlp[0].weight, std=mlp_gain) | |||||
| nn.init.normal_(block.mlp[2].weight, std=proj_gain) | |||||
| def param_groups(self): | |||||
| groups = [{ | |||||
| 'params': [ | |||||
| p for n, p in self.named_parameters() | |||||
| if 'norm' in n or n.endswith('bias') | |||||
| ], | |||||
| 'weight_decay': | |||||
| 0.0 | |||||
| }, { | |||||
| 'params': [ | |||||
| p for n, p in self.named_parameters() | |||||
| if not ('norm' in n or n.endswith('bias')) | |||||
| ] | |||||
| }] | |||||
| return groups | |||||
| def fp16(self): | |||||
| return self.apply(to_fp16) | |||||
| def clip_vit_b_32(**kwargs): | |||||
| return CLIP( | |||||
| embed_dim=512, | |||||
| image_size=224, | |||||
| patch_size=32, | |||||
| vision_dim=768, | |||||
| vision_heads=12, | |||||
| vision_layers=12, | |||||
| text_dim=512, | |||||
| text_heads=8, | |||||
| text_layers=12, | |||||
| **kwargs) | |||||
| def clip_vit_b_16(**kwargs): | |||||
| return CLIP( | |||||
| embed_dim=512, | |||||
| image_size=224, | |||||
| patch_size=16, | |||||
| vision_dim=768, | |||||
| vision_heads=12, | |||||
| vision_layers=12, | |||||
| text_dim=512, | |||||
| text_heads=8, | |||||
| text_layers=12, | |||||
| **kwargs) | |||||
| def clip_vit_l_14(**kwargs): | |||||
| return CLIP( | |||||
| embed_dim=768, | |||||
| image_size=224, | |||||
| patch_size=14, | |||||
| vision_dim=1024, | |||||
| vision_heads=16, | |||||
| vision_layers=24, | |||||
| text_dim=768, | |||||
| text_heads=12, | |||||
| text_layers=12, | |||||
| **kwargs) | |||||
| def clip_vit_l_14_336px(**kwargs): | |||||
| return CLIP( | |||||
| embed_dim=768, | |||||
| image_size=336, | |||||
| patch_size=14, | |||||
| vision_dim=1024, | |||||
| vision_heads=16, | |||||
| vision_layers=24, | |||||
| text_dim=768, | |||||
| text_heads=12, | |||||
| text_layers=12, | |||||
| **kwargs) | |||||
| def clip_vit_h_16(**kwargs): | |||||
| return CLIP( | |||||
| embed_dim=1024, | |||||
| image_size=256, | |||||
| patch_size=16, | |||||
| vision_dim=1280, | |||||
| vision_heads=16, | |||||
| vision_layers=32, | |||||
| text_dim=1024, | |||||
| text_heads=16, | |||||
| text_layers=24, | |||||
| **kwargs) | |||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .diffusion import GaussianDiffusion, beta_schedule | |||||
| else: | |||||
| _import_structure = { | |||||
| 'diffusion': ['GaussianDiffusion', 'beta_schedule'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,598 @@ | |||||
| import math | |||||
| import torch | |||||
| from .losses import discretized_gaussian_log_likelihood, kl_divergence | |||||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | |||||
| def _i(tensor, t, x): | |||||
| r"""Index tensor using t and format the output according to x. | |||||
| """ | |||||
| shape = (x.size(0), ) + (1, ) * (x.ndim - 1) | |||||
| return tensor[t].view(shape).to(x) | |||||
| def beta_schedule(schedule, | |||||
| num_timesteps=1000, | |||||
| init_beta=None, | |||||
| last_beta=None): | |||||
| if schedule == 'linear': | |||||
| scale = 1000.0 / num_timesteps | |||||
| init_beta = init_beta or scale * 0.0001 | |||||
| last_beta = last_beta or scale * 0.02 | |||||
| return torch.linspace( | |||||
| init_beta, last_beta, num_timesteps, dtype=torch.float64) | |||||
| elif schedule == 'quadratic': | |||||
| init_beta = init_beta or 0.0015 | |||||
| last_beta = last_beta or 0.0195 | |||||
| return torch.linspace( | |||||
| init_beta**0.5, last_beta**0.5, num_timesteps, | |||||
| dtype=torch.float64)**2 | |||||
| elif schedule == 'cosine': | |||||
| betas = [] | |||||
| for step in range(num_timesteps): | |||||
| t1 = step / num_timesteps | |||||
| t2 = (step + 1) / num_timesteps | |||||
| # fn = lambda u: math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 | |||||
| def fn(u): | |||||
| return math.cos((u + 0.008) / 1.008 * math.pi / 2)**2 | |||||
| betas.append(min(1.0 - fn(t2) / fn(t1), 0.999)) | |||||
| return torch.tensor(betas, dtype=torch.float64) | |||||
| else: | |||||
| raise ValueError(f'Unsupported schedule: {schedule}') | |||||
| class GaussianDiffusion(object): | |||||
| def __init__(self, | |||||
| betas, | |||||
| mean_type='eps', | |||||
| var_type='learned_range', | |||||
| loss_type='mse', | |||||
| rescale_timesteps=False): | |||||
| # check input | |||||
| if not isinstance(betas, torch.DoubleTensor): | |||||
| betas = torch.tensor(betas, dtype=torch.float64) | |||||
| assert min(betas) > 0 and max(betas) <= 1 | |||||
| assert mean_type in ['x0', 'x_{t-1}', 'eps'] | |||||
| assert var_type in [ | |||||
| 'learned', 'learned_range', 'fixed_large', 'fixed_small' | |||||
| ] | |||||
| assert loss_type in [ | |||||
| 'mse', 'rescaled_mse', 'kl', 'rescaled_kl', 'l1', 'rescaled_l1' | |||||
| ] | |||||
| self.betas = betas | |||||
| self.num_timesteps = len(betas) | |||||
| self.mean_type = mean_type | |||||
| self.var_type = var_type | |||||
| self.loss_type = loss_type | |||||
| self.rescale_timesteps = rescale_timesteps | |||||
| # alphas | |||||
| alphas = 1 - self.betas | |||||
| self.alphas_cumprod = torch.cumprod(alphas, dim=0) | |||||
| self.alphas_cumprod_prev = torch.cat( | |||||
| [alphas.new_ones([1]), self.alphas_cumprod[:-1]]) | |||||
| self.alphas_cumprod_next = torch.cat( | |||||
| [self.alphas_cumprod[1:], | |||||
| alphas.new_zeros([1])]) | |||||
| # q(x_t | x_{t-1}) | |||||
| self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) | |||||
| self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 | |||||
| - self.alphas_cumprod) | |||||
| self.log_one_minus_alphas_cumprod = torch.log(1.0 | |||||
| - self.alphas_cumprod) | |||||
| self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) | |||||
| self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod | |||||
| - 1) | |||||
| # q(x_{t-1} | x_t, x_0) | |||||
| self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / ( | |||||
| 1.0 - self.alphas_cumprod) | |||||
| self.posterior_log_variance_clipped = torch.log( | |||||
| self.posterior_variance.clamp(1e-20)) | |||||
| self.posterior_mean_coef1 = betas * torch.sqrt( | |||||
| self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) | |||||
| self.posterior_mean_coef2 = ( | |||||
| 1.0 - self.alphas_cumprod_prev) * torch.sqrt(alphas) / ( | |||||
| 1.0 - self.alphas_cumprod) | |||||
| def q_sample(self, x0, t, noise=None): | |||||
| r"""Sample from q(x_t | x_0). | |||||
| """ | |||||
| noise = torch.randn_like(x0) if noise is None else noise | |||||
| return _i(self.sqrt_alphas_cumprod, t, x0) * x0 + _i( | |||||
| self.sqrt_one_minus_alphas_cumprod, t, x0) * noise | |||||
| def q_mean_variance(self, x0, t): | |||||
| r"""Distribution of q(x_t | x_0). | |||||
| """ | |||||
| mu = _i(self.sqrt_alphas_cumprod, t, x0) * x0 | |||||
| var = _i(1.0 - self.alphas_cumprod, t, x0) | |||||
| log_var = _i(self.log_one_minus_alphas_cumprod, t, x0) | |||||
| return mu, var, log_var | |||||
| def q_posterior_mean_variance(self, x0, xt, t): | |||||
| r"""Distribution of q(x_{t-1} | x_t, x_0). | |||||
| """ | |||||
| mu = _i(self.posterior_mean_coef1, t, xt) * x0 + _i( | |||||
| self.posterior_mean_coef2, t, xt) * xt | |||||
| var = _i(self.posterior_variance, t, xt) | |||||
| log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||||
| return mu, var, log_var | |||||
| @torch.no_grad() | |||||
| def p_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None): | |||||
| r"""Sample from p(x_{t-1} | x_t). | |||||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||||
| """ | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| mu, var, log_var, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||||
| clamp, percentile, | |||||
| guide_scale) | |||||
| # random sample (with optional conditional function) | |||||
| noise = torch.randn_like(xt) | |||||
| # no noise when t == 0 | |||||
| mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||||
| if condition_fn is not None: | |||||
| grad = condition_fn(xt, self._scale_timesteps(t), **model_kwargs) | |||||
| mu = mu.float() + var * grad.float() | |||||
| xt_1 = mu + mask * torch.exp(0.5 * log_var) * noise | |||||
| return xt_1, x0 | |||||
| @torch.no_grad() | |||||
| def p_sample_loop(self, | |||||
| noise, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None): | |||||
| r"""Sample from p(x_{t-1} | x_t) p(x_{t-2} | x_{t-1}) ... p(x_0 | x_1). | |||||
| """ | |||||
| # prepare input | |||||
| b, c, h, w = noise.size() | |||||
| xt = noise | |||||
| # diffusion process | |||||
| for step in torch.arange(self.num_timesteps).flip(0): | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _ = self.p_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, condition_fn, guide_scale) | |||||
| return xt | |||||
| def p_mean_variance(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| guide_scale=None): | |||||
| r"""Distribution of p(x_{t-1} | x_t). | |||||
| """ | |||||
| # predict distribution | |||||
| if guide_scale is None: | |||||
| out = model(xt, self._scale_timesteps(t), **model_kwargs) | |||||
| else: | |||||
| # classifier-free guidance | |||||
| # (model_kwargs[0]: conditional kwargs; model_kwargs[1]: non-conditional kwargs) | |||||
| assert isinstance(model_kwargs, list) and len(model_kwargs) == 2 | |||||
| assert self.mean_type == 'eps' | |||||
| y_out = model(xt, self._scale_timesteps(t), **model_kwargs[0]) | |||||
| u_out = model(xt, self._scale_timesteps(t), **model_kwargs[1]) | |||||
| out = torch.cat( | |||||
| [ | |||||
| u_out[:, :3] + guide_scale * # noqa W504 | |||||
| (y_out[:, :3] - u_out[:, :3]), | |||||
| y_out[:, 3:] | |||||
| ], | |||||
| dim=1) | |||||
| # compute variance | |||||
| if self.var_type == 'learned': | |||||
| out, log_var = out.chunk(2, dim=1) | |||||
| var = torch.exp(log_var) | |||||
| elif self.var_type == 'learned_range': | |||||
| out, fraction = out.chunk(2, dim=1) | |||||
| min_log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||||
| max_log_var = _i(torch.log(self.betas), t, xt) | |||||
| fraction = (fraction + 1) / 2.0 | |||||
| log_var = fraction * max_log_var + (1 - fraction) * min_log_var | |||||
| var = torch.exp(log_var) | |||||
| elif self.var_type == 'fixed_large': | |||||
| var = _i( | |||||
| torch.cat([self.posterior_variance[1:2], self.betas[1:]]), t, | |||||
| xt) | |||||
| log_var = torch.log(var) | |||||
| elif self.var_type == 'fixed_small': | |||||
| var = _i(self.posterior_variance, t, xt) | |||||
| log_var = _i(self.posterior_log_variance_clipped, t, xt) | |||||
| # compute mean and x0 | |||||
| if self.mean_type == 'x_{t-1}': | |||||
| mu = out # x_{t-1} | |||||
| x0 = _i(1.0 / self.posterior_mean_coef1, t, xt) * mu - _i( | |||||
| self.posterior_mean_coef2 / self.posterior_mean_coef1, t, | |||||
| xt) * xt | |||||
| elif self.mean_type == 'x0': | |||||
| x0 = out | |||||
| mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) | |||||
| elif self.mean_type == 'eps': | |||||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * out | |||||
| mu, _, _ = self.q_posterior_mean_variance(x0, xt, t) | |||||
| # restrict the range of x0 | |||||
| if percentile is not None: | |||||
| assert percentile > 0 and percentile <= 1 # e.g., 0.995 | |||||
| s = torch.quantile( | |||||
| x0.flatten(1).abs(), percentile, | |||||
| dim=1).clamp_(1.0).view(-1, 1, 1, 1) | |||||
| x0 = torch.min(s, torch.max(-s, x0)) / s | |||||
| elif clamp is not None: | |||||
| x0 = x0.clamp(-clamp, clamp) | |||||
| return mu, var, log_var, x0 | |||||
| @torch.no_grad() | |||||
| def ddim_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20, | |||||
| eta=0.0): | |||||
| stride = self.num_timesteps // ddim_timesteps | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, | |||||
| percentile, guide_scale) | |||||
| if condition_fn is not None: | |||||
| # x0 -> eps | |||||
| alpha = _i(self.alphas_cumprod, t, xt) | |||||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = eps - (1 - alpha).sqrt() * condition_fn( | |||||
| xt, self._scale_timesteps(t), **model_kwargs) | |||||
| # eps -> x0 | |||||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||||
| # derive variables | |||||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| alphas = _i(self.alphas_cumprod, t, xt) | |||||
| alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) | |||||
| sigmas = eta * torch.sqrt((1 - alphas_prev) / # noqa W504 | |||||
| (1 - alphas) * # noqa W504 | |||||
| (1 - alphas / alphas_prev)) | |||||
| # random sample | |||||
| noise = torch.randn_like(xt) | |||||
| direction = torch.sqrt(1 - alphas_prev - sigmas**2) * eps | |||||
| mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||||
| xt_1 = torch.sqrt(alphas_prev) * x0 + direction + mask * sigmas * noise | |||||
| return xt_1, x0 | |||||
| @torch.no_grad() | |||||
| def ddim_sample_loop(self, | |||||
| noise, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20, | |||||
| eta=0.0): | |||||
| # prepare input | |||||
| b, c, h, w = noise.size() | |||||
| xt = noise | |||||
| # diffusion process (TODO: clamp is inaccurate! Consider replacing the stride by explicit prev/next steps) | |||||
| steps = (1 + torch.arange(0, self.num_timesteps, | |||||
| self.num_timesteps // ddim_timesteps)).clamp( | |||||
| 0, self.num_timesteps - 1).flip(0) | |||||
| for step in steps: | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _ = self.ddim_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, condition_fn, guide_scale, | |||||
| ddim_timesteps, eta) | |||||
| return xt | |||||
| @torch.no_grad() | |||||
| def ddim_reverse_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20): | |||||
| r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic). | |||||
| """ | |||||
| stride = self.num_timesteps // ddim_timesteps | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, clamp, | |||||
| percentile, guide_scale) | |||||
| # derive variables | |||||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| alphas_next = _i( | |||||
| torch.cat( | |||||
| [self.alphas_cumprod, | |||||
| self.alphas_cumprod.new_zeros([1])]), | |||||
| (t + stride).clamp(0, self.num_timesteps), xt) | |||||
| # reverse sample | |||||
| mu = torch.sqrt(alphas_next) * x0 + torch.sqrt(1 - alphas_next) * eps | |||||
| return mu, x0 | |||||
| @torch.no_grad() | |||||
| def ddim_reverse_sample_loop(self, | |||||
| x0, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| guide_scale=None, | |||||
| ddim_timesteps=20): | |||||
| # prepare input | |||||
| b, c, h, w = x0.size() | |||||
| xt = x0 | |||||
| # reconstruction steps | |||||
| steps = torch.arange(0, self.num_timesteps, | |||||
| self.num_timesteps // ddim_timesteps) | |||||
| for step in steps: | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _ = self.ddim_reverse_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, guide_scale, | |||||
| ddim_timesteps) | |||||
| return xt | |||||
| @torch.no_grad() | |||||
| def plms_sample(self, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| plms_timesteps=20): | |||||
| r"""Sample from p(x_{t-1} | x_t) using PLMS. | |||||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||||
| """ | |||||
| stride = self.num_timesteps // plms_timesteps | |||||
| # function for compute eps | |||||
| def compute_eps(xt, t): | |||||
| # predict distribution of p(x_{t-1} | x_t) | |||||
| _, _, _, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||||
| clamp, percentile, guide_scale) | |||||
| # condition | |||||
| if condition_fn is not None: | |||||
| # x0 -> eps | |||||
| alpha = _i(self.alphas_cumprod, t, xt) | |||||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt | |||||
| - x0) / _i(self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| eps = eps - (1 - alpha).sqrt() * condition_fn( | |||||
| xt, self._scale_timesteps(t), **model_kwargs) | |||||
| # eps -> x0 | |||||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||||
| # derive eps | |||||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| return eps | |||||
| # function for compute x_0 and x_{t-1} | |||||
| def compute_x0(eps, t): | |||||
| # eps -> x0 | |||||
| x0 = _i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) * eps | |||||
| # deterministic sample | |||||
| alphas_prev = _i(self.alphas_cumprod, (t - stride).clamp(0), xt) | |||||
| direction = torch.sqrt(1 - alphas_prev) * eps | |||||
| # mask = t.ne(0).float().view(-1, *((1, ) * (xt.ndim - 1))) | |||||
| xt_1 = torch.sqrt(alphas_prev) * x0 + direction | |||||
| return xt_1, x0 | |||||
| # PLMS sample | |||||
| eps = compute_eps(xt, t) | |||||
| if len(eps_cache) == 0: | |||||
| # 2nd order pseudo improved Euler | |||||
| xt_1, x0 = compute_x0(eps, t) | |||||
| eps_next = compute_eps(xt_1, (t - stride).clamp(0)) | |||||
| eps_prime = (eps + eps_next) / 2.0 | |||||
| elif len(eps_cache) == 1: | |||||
| # 2nd order pseudo linear multistep (Adams-Bashforth) | |||||
| eps_prime = (3 * eps - eps_cache[-1]) / 2.0 | |||||
| elif len(eps_cache) == 2: | |||||
| # 3nd order pseudo linear multistep (Adams-Bashforth) | |||||
| eps_prime = (23 * eps - 16 * eps_cache[-1] | |||||
| + 5 * eps_cache[-2]) / 12.0 | |||||
| elif len(eps_cache) >= 3: | |||||
| # 4nd order pseudo linear multistep (Adams-Bashforth) | |||||
| eps_prime = (55 * eps - 59 * eps_cache[-1] + 37 * eps_cache[-2] | |||||
| - 9 * eps_cache[-3]) / 24.0 | |||||
| xt_1, x0 = compute_x0(eps_prime, t) | |||||
| return xt_1, x0, eps | |||||
| @torch.no_grad() | |||||
| def plms_sample_loop(self, | |||||
| noise, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None, | |||||
| condition_fn=None, | |||||
| guide_scale=None, | |||||
| plms_timesteps=20): | |||||
| # prepare input | |||||
| b, c, h, w = noise.size() | |||||
| xt = noise | |||||
| # diffusion process | |||||
| steps = (1 + torch.arange(0, self.num_timesteps, | |||||
| self.num_timesteps // plms_timesteps)).clamp( | |||||
| 0, self.num_timesteps - 1).flip(0) | |||||
| eps_cache = [] | |||||
| for step in steps: | |||||
| # PLMS sampling step | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=xt.device) | |||||
| xt, _, eps = self.plms_sample(xt, t, model, model_kwargs, clamp, | |||||
| percentile, condition_fn, | |||||
| guide_scale, plms_timesteps, | |||||
| eps_cache) | |||||
| # update eps cache | |||||
| eps_cache.append(eps) | |||||
| if len(eps_cache) >= 4: | |||||
| eps_cache.pop(0) | |||||
| return xt | |||||
| def loss(self, x0, t, model, model_kwargs={}, noise=None): | |||||
| noise = torch.randn_like(x0) if noise is None else noise | |||||
| xt = self.q_sample(x0, t, noise=noise) | |||||
| # compute loss | |||||
| if self.loss_type in ['kl', 'rescaled_kl']: | |||||
| loss, _ = self.variational_lower_bound(x0, xt, t, model, | |||||
| model_kwargs) | |||||
| if self.loss_type == 'rescaled_kl': | |||||
| loss = loss * self.num_timesteps | |||||
| elif self.loss_type in ['mse', 'rescaled_mse', 'l1', 'rescaled_l1']: | |||||
| out = model(xt, self._scale_timesteps(t), **model_kwargs) | |||||
| # VLB for variation | |||||
| loss_vlb = 0.0 | |||||
| if self.var_type in ['learned', 'learned_range']: | |||||
| out, var = out.chunk(2, dim=1) | |||||
| frozen = torch.cat([ | |||||
| out.detach(), var | |||||
| ], dim=1) # learn var without affecting the prediction of mean | |||||
| loss_vlb, _ = self.variational_lower_bound( | |||||
| x0, xt, t, model=lambda *args, **kwargs: frozen) | |||||
| if self.loss_type.startswith('rescaled_'): | |||||
| loss_vlb = loss_vlb * self.num_timesteps / 1000.0 | |||||
| # MSE/L1 for x0/eps | |||||
| target = { | |||||
| 'eps': noise, | |||||
| 'x0': x0, | |||||
| 'x_{t-1}': self.q_posterior_mean_variance(x0, xt, t)[0] | |||||
| }[self.mean_type] | |||||
| loss = (out - target).pow(1 if self.loss_type.endswith('l1') else 2 | |||||
| ).abs().flatten(1).mean(dim=1) | |||||
| # total loss | |||||
| loss = loss + loss_vlb | |||||
| return loss | |||||
| def variational_lower_bound(self, | |||||
| x0, | |||||
| xt, | |||||
| t, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None): | |||||
| # compute groundtruth and predicted distributions | |||||
| mu1, _, log_var1 = self.q_posterior_mean_variance(x0, xt, t) | |||||
| mu2, _, log_var2, x0 = self.p_mean_variance(xt, t, model, model_kwargs, | |||||
| clamp, percentile) | |||||
| # compute KL loss | |||||
| kl = kl_divergence(mu1, log_var1, mu2, log_var2) | |||||
| kl = kl.flatten(1).mean(dim=1) / math.log(2.0) | |||||
| # compute discretized NLL loss (for p(x0 | x1) only) | |||||
| nll = -discretized_gaussian_log_likelihood( | |||||
| x0, mean=mu2, log_scale=0.5 * log_var2) | |||||
| nll = nll.flatten(1).mean(dim=1) / math.log(2.0) | |||||
| # NLL for p(x0 | x1) and KL otherwise | |||||
| vlb = torch.where(t == 0, nll, kl) | |||||
| return vlb, x0 | |||||
| @torch.no_grad() | |||||
| def variational_lower_bound_loop(self, | |||||
| x0, | |||||
| model, | |||||
| model_kwargs={}, | |||||
| clamp=None, | |||||
| percentile=None): | |||||
| r"""Compute the entire variational lower bound, measured in bits-per-dim. | |||||
| """ | |||||
| # prepare input and output | |||||
| b, c, h, w = x0.size() | |||||
| metrics = {'vlb': [], 'mse': [], 'x0_mse': []} | |||||
| # loop | |||||
| for step in torch.arange(self.num_timesteps).flip(0): | |||||
| # compute VLB | |||||
| t = torch.full((b, ), step, dtype=torch.long, device=x0.device) | |||||
| noise = torch.randn_like(x0) | |||||
| xt = self.q_sample(x0, t, noise) | |||||
| vlb, pred_x0 = self.variational_lower_bound( | |||||
| x0, xt, t, model, model_kwargs, clamp, percentile) | |||||
| # predict eps from x0 | |||||
| eps = (_i(self.sqrt_recip_alphas_cumprod, t, xt) * xt - x0) / _i( | |||||
| self.sqrt_recipm1_alphas_cumprod, t, xt) | |||||
| # collect metrics | |||||
| metrics['vlb'].append(vlb) | |||||
| metrics['x0_mse'].append( | |||||
| (pred_x0 - x0).square().flatten(1).mean(dim=1)) | |||||
| metrics['mse'].append( | |||||
| (eps - noise).square().flatten(1).mean(dim=1)) | |||||
| metrics = {k: torch.stack(v, dim=1) for k, v in metrics.items()} | |||||
| # compute the prior KL term for VLB, measured in bits-per-dim | |||||
| mu, _, log_var = self.q_mean_variance(x0, t) | |||||
| kl_prior = kl_divergence(mu, log_var, torch.zeros_like(mu), | |||||
| torch.zeros_like(log_var)) | |||||
| kl_prior = kl_prior.flatten(1).mean(dim=1) / math.log(2.0) | |||||
| # update metrics | |||||
| metrics['prior_bits_per_dim'] = kl_prior | |||||
| metrics['total_bits_per_dim'] = metrics['vlb'].sum(dim=1) + kl_prior | |||||
| return metrics | |||||
| def _scale_timesteps(self, t): | |||||
| if self.rescale_timesteps: | |||||
| return t.float() * 1000.0 / self.num_timesteps | |||||
| return t | |||||
| @@ -0,0 +1,35 @@ | |||||
| import math | |||||
| import torch | |||||
| __all__ = ['kl_divergence', 'discretized_gaussian_log_likelihood'] | |||||
| def kl_divergence(mu1, logvar1, mu2, logvar2): | |||||
| return 0.5 * ( | |||||
| -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + # noqa W504 | |||||
| ((mu1 - mu2)**2) * torch.exp(-logvar2)) | |||||
| def standard_normal_cdf(x): | |||||
| r"""A fast approximation of the cumulative distribution function of the standard normal. | |||||
| """ | |||||
| return 0.5 * (1.0 + torch.tanh( | |||||
| math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) | |||||
| def discretized_gaussian_log_likelihood(x0, mean, log_scale): | |||||
| assert x0.shape == mean.shape == log_scale.shape | |||||
| cx = x0 - mean | |||||
| inv_stdv = torch.exp(-log_scale) | |||||
| cdf_plus = standard_normal_cdf(inv_stdv * (cx + 1.0 / 255.0)) | |||||
| cdf_min = standard_normal_cdf(inv_stdv * (cx - 1.0 / 255.0)) | |||||
| log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12)) | |||||
| log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12)) | |||||
| cdf_delta = cdf_plus - cdf_min | |||||
| log_probs = torch.where( | |||||
| x0 < -0.999, log_cdf_plus, | |||||
| torch.where(x0 > 0.999, log_one_minus_cdf_min, | |||||
| torch.log(cdf_delta.clamp(min=1e-12)))) | |||||
| assert log_probs.shape == x0.shape | |||||
| return log_probs | |||||
| @@ -118,6 +118,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { | |||||
| Tasks.image_classification_dailylife: | Tasks.image_classification_dailylife: | ||||
| (Pipelines.daily_image_classification, | (Pipelines.daily_image_classification, | ||||
| 'damo/cv_vit-base_image-classification_Dailylife-labels'), | 'damo/cv_vit-base_image-classification_Dailylife-labels'), | ||||
| Tasks.image_to_image_generation: | |||||
| (Pipelines.image_to_image_generation, | |||||
| 'damo/cv_latent_diffusion_image2image_generate'), | |||||
| } | } | ||||
| @@ -20,6 +20,7 @@ if TYPE_CHECKING: | |||||
| from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline | ||||
| from .image_matting_pipeline import ImageMattingPipeline | from .image_matting_pipeline import ImageMattingPipeline | ||||
| from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | from .image_super_resolution_pipeline import ImageSuperResolutionPipeline | ||||
| from .image_to_image_generation_pipeline import Image2ImageGenerationePipeline | |||||
| from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | from .image_to_image_translation_pipeline import Image2ImageTranslationPipeline | ||||
| from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | from .product_retrieval_embedding_pipeline import ProductRetrievalEmbeddingPipeline | ||||
| from .style_transfer_pipeline import StyleTransferPipeline | from .style_transfer_pipeline import StyleTransferPipeline | ||||
| @@ -51,6 +52,8 @@ else: | |||||
| 'product_retrieval_embedding_pipeline': | 'product_retrieval_embedding_pipeline': | ||||
| ['ProductRetrievalEmbeddingPipeline'], | ['ProductRetrievalEmbeddingPipeline'], | ||||
| 'live_category_pipeline': ['LiveCategoryPipeline'], | 'live_category_pipeline': ['LiveCategoryPipeline'], | ||||
| 'image_to_image_generation_pipeline': | |||||
| ['Image2ImageGenerationePipeline'], | |||||
| 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | 'ocr_detection_pipeline': ['OCRDetectionPipeline'], | ||||
| 'style_transfer_pipeline': ['StyleTransferPipeline'], | 'style_transfer_pipeline': ['StyleTransferPipeline'], | ||||
| 'video_category_pipeline': ['VideoCategoryPipeline'], | 'video_category_pipeline': ['VideoCategoryPipeline'], | ||||
| @@ -0,0 +1,250 @@ | |||||
| import os.path as osp | |||||
| from typing import Any, Dict | |||||
| import cv2 | |||||
| import numpy as np | |||||
| import PIL | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| import torchvision.transforms as T | |||||
| import torchvision.transforms.functional as TF | |||||
| from PIL import Image | |||||
| from torchvision.utils import save_image | |||||
| import modelscope.models.cv.image_to_image_generation.data as data | |||||
| import modelscope.models.cv.image_to_image_generation.models as models | |||||
| import modelscope.models.cv.image_to_image_generation.ops as ops | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.cv.image_to_image_generation.model import UNet | |||||
| from modelscope.models.cv.image_to_image_generation.models.clip import \ | |||||
| VisionTransformer | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Input, Pipeline | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import LoadImage | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| @PIPELINES.register_module( | |||||
| Tasks.image_to_image_generation, | |||||
| module_name=Pipelines.image_to_image_generation) | |||||
| class Image2ImageGenerationePipeline(Pipeline): | |||||
| def __init__(self, model: str, **kwargs): | |||||
| """ | |||||
| use `model` to create a image-to-image generation pipeline | |||||
| Args: | |||||
| model: model id on modelscope hub. | |||||
| """ | |||||
| super().__init__(model=model) | |||||
| config_path = osp.join(self.model, ModelFile.CONFIGURATION) | |||||
| logger.info(f'loading config from {config_path}') | |||||
| self.cfg = Config.from_file(config_path) | |||||
| if torch.cuda.is_available(): | |||||
| self._device = torch.device('cuda') | |||||
| else: | |||||
| self._device = torch.device('cpu') | |||||
| self.repetition = 4 | |||||
| # load vit model | |||||
| vit_model_path = osp.join(self.model, | |||||
| self.cfg.ModelPath.vit_model_path) | |||||
| logger.info(f'loading vit model from {vit_model_path}') | |||||
| self.vit = VisionTransformer( | |||||
| image_size=self.cfg.Params.vit.vit_image_size, | |||||
| patch_size=self.cfg.Params.vit.vit_patch_size, | |||||
| dim=self.cfg.Params.vit.vit_dim, | |||||
| out_dim=self.cfg.Params.vit.vit_out_dim, | |||||
| num_heads=self.cfg.Params.vit.vit_num_heads, | |||||
| num_layers=self.cfg.Params.vit.vit_num_layers).eval( | |||||
| ).requires_grad_(False).to(self._device) # noqa E123 | |||||
| state = torch.load(vit_model_path) | |||||
| state = { | |||||
| k[len('visual.'):]: v | |||||
| for k, v in state.items() if k.startswith('visual.') | |||||
| } | |||||
| self.vit.load_state_dict(state) | |||||
| logger.info('load vit model done') | |||||
| # load autoencoder model | |||||
| ae_model_path = osp.join(self.model, self.cfg.ModelPath.ae_model_path) | |||||
| logger.info(f'loading autoencoder model from {ae_model_path}') | |||||
| self.autoencoder = models.VQAutoencoder( | |||||
| dim=self.cfg.Params.ae.ae_dim, | |||||
| z_dim=self.cfg.Params.ae.ae_z_dim, | |||||
| dim_mult=self.cfg.Params.ae.ae_dim_mult, | |||||
| attn_scales=self.cfg.Params.ae.ae_attn_scales, | |||||
| codebook_size=self.cfg.Params.ae.ae_codebook_size).eval( | |||||
| ).requires_grad_(False).to(self._device) # noqa E123 | |||||
| self.autoencoder.load_state_dict( | |||||
| torch.load(ae_model_path, map_location=self._device)) | |||||
| logger.info('load autoencoder model done') | |||||
| # load decoder model | |||||
| decoder_model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) | |||||
| logger.info(f'loading decoder model from {decoder_model_path}') | |||||
| self.decoder = UNet( | |||||
| resolution=self.cfg.Params.unet.unet_resolution, | |||||
| in_dim=self.cfg.Params.unet.unet_in_dim, | |||||
| dim=self.cfg.Params.unet.unet_dim, | |||||
| label_dim=self.cfg.Params.vit.vit_out_dim, | |||||
| context_dim=self.cfg.Params.unet.unet_context_dim, | |||||
| out_dim=self.cfg.Params.unet.unet_out_dim, | |||||
| dim_mult=self.cfg.Params.unet.unet_dim_mult, | |||||
| num_heads=self.cfg.Params.unet.unet_num_heads, | |||||
| head_dim=None, | |||||
| num_res_blocks=self.cfg.Params.unet.unet_res_blocks, | |||||
| attn_scales=self.cfg.Params.unet.unet_attn_scales, | |||||
| dropout=self.cfg.Params.unet.unet_dropout).eval().requires_grad_( | |||||
| False).to(self._device) | |||||
| self.decoder.load_state_dict( | |||||
| torch.load(decoder_model_path, map_location=self._device)) | |||||
| logger.info('load decoder model done') | |||||
| # diffusion | |||||
| logger.info('Initialization diffusion ...') | |||||
| betas = ops.beta_schedule(self.cfg.Params.diffusion.schedule, | |||||
| self.cfg.Params.diffusion.num_timesteps) | |||||
| self.diffusion = ops.GaussianDiffusion( | |||||
| betas=betas, | |||||
| mean_type=self.cfg.Params.diffusion.mean_type, | |||||
| var_type=self.cfg.Params.diffusion.var_type, | |||||
| loss_type=self.cfg.Params.diffusion.loss_type, | |||||
| rescale_timesteps=False) | |||||
| def preprocess(self, input: Input) -> Dict[str, Any]: | |||||
| input_img_list = [] | |||||
| if isinstance(input, str): | |||||
| input_img_list = [input] | |||||
| input_type = 0 | |||||
| elif isinstance(input, tuple) and len(input) == 2: | |||||
| input_img_list = list(input) | |||||
| input_type = 1 | |||||
| else: | |||||
| raise TypeError( | |||||
| 'modelscope error: Only support "str" or "tuple (img1, img2)" , but got {type(input)}' | |||||
| ) | |||||
| if input_type == 0: | |||||
| logger.info('Processing Similar Image Generation mode') | |||||
| if input_type == 1: | |||||
| logger.info('Processing Interpolation mode') | |||||
| img_list = [] | |||||
| for i, input_img in enumerate(input_img_list): | |||||
| img = LoadImage.convert_to_img(input_img) | |||||
| logger.info(f'Load {i}-th image done') | |||||
| img_list.append(img) | |||||
| transforms = T.Compose([ | |||||
| data.PadToSquare(), | |||||
| T.Resize( | |||||
| self.cfg.DATA.scale_size, | |||||
| interpolation=T.InterpolationMode.BICUBIC), | |||||
| T.ToTensor(), | |||||
| T.Normalize(mean=self.cfg.DATA.mean, std=self.cfg.DATA.std) | |||||
| ]) | |||||
| y_list = [] | |||||
| for img in img_list: | |||||
| img = transforms(img) | |||||
| imgs = torch.unsqueeze(img, 0) | |||||
| imgs = imgs.to(self._device) | |||||
| imgs_x0 = self.autoencoder.encode(imgs) | |||||
| b, c, h, w = imgs_x0.shape | |||||
| aug_imgs = TF.normalize( | |||||
| F.interpolate( | |||||
| imgs.add(1).div(2), (self.cfg.Params.vit.vit_image_size, | |||||
| self.cfg.Params.vit.vit_image_size), | |||||
| mode='bilinear', | |||||
| align_corners=True), self.cfg.Params.vit.vit_mean, | |||||
| self.cfg.Params.vit.vit_std) | |||||
| uy = self.vit(aug_imgs) | |||||
| y = F.normalize(uy, p=2, dim=1) | |||||
| y_list.append(y) | |||||
| if input_type == 0: | |||||
| result = { | |||||
| 'image_data': y_list[0], | |||||
| 'c': c, | |||||
| 'h': h, | |||||
| 'w': w, | |||||
| 'type': input_type | |||||
| } | |||||
| elif input_type == 1: | |||||
| result = { | |||||
| 'image_data': y_list[0], | |||||
| 'image_data_s': y_list[1], | |||||
| 'c': c, | |||||
| 'h': h, | |||||
| 'w': w, | |||||
| 'type': input_type | |||||
| } | |||||
| return result | |||||
| @torch.no_grad() | |||||
| def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: | |||||
| type_ = input['type'] | |||||
| if type_ == 0: | |||||
| # Similar Image Generation # | |||||
| y = input['image_data'] | |||||
| # fix seed | |||||
| torch.manual_seed(1 * 8888) | |||||
| torch.cuda.manual_seed(1 * 8888) | |||||
| i_y = y.repeat(self.repetition, 1) | |||||
| # sample images | |||||
| x0 = self.diffusion.ddim_sample_loop( | |||||
| noise=torch.randn(self.repetition, input['c'], input['h'], | |||||
| input['w']).to(self._device), | |||||
| model=self.decoder, | |||||
| model_kwargs=[{ | |||||
| 'y': i_y | |||||
| }, { | |||||
| 'y': torch.zeros_like(i_y) | |||||
| }], | |||||
| guide_scale=1.0, | |||||
| clamp=None, | |||||
| ddim_timesteps=50, | |||||
| eta=1.0) | |||||
| i_gen_imgs = self.autoencoder.decode(x0) | |||||
| return {OutputKeys.OUTPUT_IMG: i_gen_imgs} | |||||
| else: | |||||
| # Interpolation # | |||||
| # get content-style pairs | |||||
| y = input['image_data'] | |||||
| y_s = input['image_data_s'] | |||||
| # fix seed | |||||
| torch.manual_seed(1 * 8888) | |||||
| torch.cuda.manual_seed(1 * 8888) | |||||
| noise = torch.randn(self.repetition, input['c'], input['h'], | |||||
| input['w']).to(self._device) | |||||
| # interpolation between y_cid and y_sid | |||||
| factors = torch.linspace(0, 1, self.repetition).unsqueeze(1).to( | |||||
| self._device) | |||||
| i_y = (1 - factors) * y + factors * y_s | |||||
| # sample images | |||||
| x0 = self.diffusion.ddim_sample_loop( | |||||
| noise=noise, | |||||
| model=self.decoder, | |||||
| model_kwargs=[{ | |||||
| 'y': i_y | |||||
| }, { | |||||
| 'y': torch.zeros_like(i_y) | |||||
| }], | |||||
| guide_scale=3.0, | |||||
| clamp=None, | |||||
| ddim_timesteps=50, | |||||
| eta=0.0) | |||||
| i_gen_imgs = self.autoencoder.decode(x0) | |||||
| return {OutputKeys.OUTPUT_IMG: i_gen_imgs} | |||||
| def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: | |||||
| return inputs | |||||
| @@ -42,6 +42,7 @@ class CVTasks(object): | |||||
| video_category = 'video-category' | video_category = 'video-category' | ||||
| image_classification_imagenet = 'image-classification-imagenet' | image_classification_imagenet = 'image-classification-imagenet' | ||||
| image_classification_dailylife = 'image-classification-dailylife' | image_classification_dailylife = 'image-classification-dailylife' | ||||
| image_to_image_generation = 'image-to-image-generation' | |||||
| class NLPTasks(object): | class NLPTasks(object): | ||||
| @@ -0,0 +1,48 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os.path as osp | |||||
| import shutil | |||||
| import unittest | |||||
| from torchvision.utils import save_image | |||||
| from modelscope.fileio import File | |||||
| from modelscope.msdatasets import MsDataset | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class Image2ImageGenerationTest(unittest.TestCase): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_modelhub(self): | |||||
| r"""We provide two generation modes, i.e., Similar Image Generation and Interpolation. | |||||
| You can pass the following parameters for different mode. | |||||
| 1. Similar Image Generation Mode: | |||||
| 2. Interpolation Mode: | |||||
| """ | |||||
| img2img_gen_pipeline = pipeline( | |||||
| Tasks.image_to_image_generation, | |||||
| model='damo/cv_latent_diffusion_image2image_generate') | |||||
| # Similar Image Generation mode | |||||
| result1 = img2img_gen_pipeline('data/test/images/img2img_input.jpg') | |||||
| # Interpolation Mode | |||||
| result2 = img2img_gen_pipeline(('data/test/images/img2img_input.jpg', | |||||
| 'data/test/images/img2img_style.jpg')) | |||||
| save_image( | |||||
| result1['output_img'].clamp(-1, 1), | |||||
| 'result1.jpg', | |||||
| range=(-1, 1), | |||||
| normalize=True, | |||||
| nrow=4) | |||||
| save_image( | |||||
| result2['output_img'].clamp(-1, 1), | |||||
| 'result2.jpg', | |||||
| range=(-1, 1), | |||||
| normalize=True, | |||||
| nrow=4) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||