@@ -10,22 +10,23 @@ import torch.nn.functional as F
 from modelscope.metainfo import Models
 from modelscope.models import Model
 from modelscope.models.builder import MODELS
-from modelscope.models.multi_modal.imagen.diffusion import (GaussianDiffusion,
-                                                            beta_schedule)
-from modelscope.models.multi_modal.imagen.structbert import (BertConfig,
-                                                             BertModel)
-from modelscope.models.multi_modal.imagen.tokenizer import FullTokenizer
-from modelscope.models.multi_modal.imagen.unet_generator import ImagenGenerator
-from modelscope.models.multi_modal.imagen.unet_imagen_upsampler_256 import \
-    SuperResUNet256
+from modelscope.models.multi_modal.diffusion.diffusion import (
+    GaussianDiffusion, beta_schedule)
+from modelscope.models.multi_modal.diffusion.structbert import (BertConfig,
+                                                                BertModel)
+from modelscope.models.multi_modal.diffusion.tokenizer import FullTokenizer
+from modelscope.models.multi_modal.diffusion.unet_generator import \
+    DiffusionGenerator
+from modelscope.models.multi_modal.diffusion.unet_upsampler_256 import \
+    SuperResUNet256
-from modelscope.models.multi_modal.imagen.unet_upsampler_1024 import \
-    ImagenUpsampler1024
+from modelscope.models.multi_modal.diffusion.unet_upsampler_1024 import \
+    SuperResUNet1024
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger

 logger = get_logger()

-__all__ = ['ImagenForTextToImageSynthesis']
+__all__ = ['DiffusionForTextToImageSynthesis']


 def make_diffusion(schedule,
@@ -68,13 +69,13 @@ class Tokenizer(object):
         return input_ids, segment_ids, input_mask


-class ImagenModel(nn.Module):
+class DiffusionModel(nn.Module):

     def __init__(self, model_dir):
-        super(ImagenModel, self).__init__()
+        super(DiffusionModel, self).__init__()
         # including text and generator config
         model_config = json.load(
-            open('{}/imagen_config.json'.format(model_dir)))
+            open('{}/model_config.json'.format(model_dir)))

         # text encoder
         text_config = model_config['text_config']
@@ -82,17 +83,15 @@ class ImagenModel(nn.Module):

         # generator (64x64)
         generator_config = model_config['generator_config']
-        self.unet_generator = ImagenGenerator(**generator_config)
+        self.unet_generator = DiffusionGenerator(**generator_config)

-        # imagen upsampler (256x256)
-        imagen_upsampler_256_config = model_config[
-            'imagen_upsampler_256_config']
-        self.unet_imagen_upsampler_256 = SuperResUNet256(
-            **imagen_upsampler_256_config)
+        # upsampler (256x256)
+        upsampler_256_config = model_config['upsampler_256_config']
+        self.unet_upsampler_256 = SuperResUNet256(**upsampler_256_config)

-        # dalle2 upsampler (1024x1024)
+        # upsampler (1024x1024)
         upsampler_1024_config = model_config['upsampler_1024_config']
-        self.unet_upsampler_1024 = ImagenUpsampler1024(**upsampler_1024_config)
+        self.unet_upsampler_1024 = SuperResUNet1024(**upsampler_1024_config)

     def forward(self, noise, timesteps, input_ids, token_type_ids,
                 attention_mask):
@@ -102,39 +101,39 @@ class ImagenModel(nn.Module):
             attention_mask=attention_mask)
         context = context[-1]
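+        # stage 1: base generator denoises to a 64x64 image conditioned on the text encoding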
         x = self.unet_generator(noise, timesteps, y, context, attention_mask)
-        x = self.unet_imagen_upsampler_256(noise, timesteps, x,
-                                           torch.zeros_like(timesteps), y,
-                                           context, attention_mask)
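+        # stage 2: text-conditioned super-resolution from 64x64 to 256x256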
+        x = self.unet_upsampler_256(noise, timesteps, x,
+                                    torch.zeros_like(timesteps), y, context,
+                                    attention_mask)
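+        # stage 3: super-resolution from 256x256 to 1024x1024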
         x = self.unet_upsampler_1024(x, timesteps, x)
         return x


 @MODELS.register_module(
-    Tasks.text_to_image_synthesis, module_name=Models.imagen)
-class ImagenForTextToImageSynthesis(Model):
+    Tasks.text_to_image_synthesis, module_name=Models.diffusion)
+class DiffusionForTextToImageSynthesis(Model):

     def __init__(self, model_dir, device_id=-1):
         super().__init__(model_dir=model_dir, device_id=device_id)
-        imagen_model = ImagenModel(model_dir=model_dir)
+        diffusion_model = DiffusionModel(model_dir=model_dir)
         pretrained_params = torch.load(
             osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
-        imagen_model.load_state_dict(pretrained_params)
-        imagen_model.eval()
+        diffusion_model.load_state_dict(pretrained_params)
+        diffusion_model.eval()

         self.device_id = device_id
         if self.device_id >= 0:
             self.device = torch.device(f'cuda:{self.device_id}')
-            imagen_model.to('cuda:{}'.format(self.device_id))
+            diffusion_model.to('cuda:{}'.format(self.device_id))
             logger.info('Use GPU: {}'.format(self.device_id))
         else:
             self.device = torch.device('cpu')
             logger.info('Use CPU for inference')

         # modules
-        self.text_encoder = imagen_model.text_encoder
-        self.unet_generator = imagen_model.unet_generator
-        self.unet_imagen_upsampler_256 = imagen_model.unet_imagen_upsampler_256
-        self.unet_upsampler_1024 = imagen_model.unet_upsampler_1024
+        self.text_encoder = diffusion_model.text_encoder
+        self.unet_generator = diffusion_model.unet_generator
+        self.unet_upsampler_256 = diffusion_model.unet_upsampler_256
+        self.unet_upsampler_1024 = diffusion_model.unet_upsampler_1024

         # text tokenizer
         vocab_path = '{}/vocab.txt'.format(model_dir)
@@ -145,8 +144,8 @@ class ImagenForTextToImageSynthesis(Model):
             open('{}/diffusion_config.json'.format(model_dir)))
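+        # one Gaussian diffusion (noise schedule) per stage of the cascade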
         self.diffusion_generator = make_diffusion(
             **diffusion_params['generator_config'])
-        self.diffusion_imagen_upsampler_256 = make_diffusion(
-            **diffusion_params['imagen_upsampler_256_config'])
+        self.diffusion_upsampler_256 = make_diffusion(
+            **diffusion_params['upsampler_256_config'])
         self.diffusion_upsampler_1024 = make_diffusion(
             **diffusion_params['upsampler_1024_config'])

@@ -166,9 +165,9 @@ class ImagenForTextToImageSynthesis(Model):
             attention_mask=attention_mask)
         context = context[-1]
         x = self.unet_generator(noise, timesteps, y, context, attention_mask)
-        x = self.unet_imagen_upsampler_256(noise, timesteps, x,
-                                           torch.zeros_like(timesteps), y,
-                                           context, attention_mask)
+        x = self.unet_upsampler_256(noise, timesteps, x,
+                                    torch.zeros_like(timesteps), y, context,
+                                    attention_mask)
         x = self.unet_upsampler_1024(x, timesteps, x)
         img = x.clamp(-1, 1).add(1).mul(127.5)
         img = img.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
@@ -217,9 +216,9 @@ class ImagenForTextToImageSynthesis(Model):
         if not input.get('debug', False):
             img = F.interpolate(
                 img, scale_factor=4.0, mode='bilinear', align_corners=False)
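+            # super-resolve via DDIM sampling, conditioned on the bilinearly upsampled image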
-            img = self.diffusion_imagen_upsampler_256.ddim_sample_loop(
+            img = self.diffusion_upsampler_256.ddim_sample_loop(
                 noise=torch.randn_like(img),
-                model=self.unet_imagen_upsampler_256,
+                model=self.unet_upsampler_256,
                 model_kwargs=[{
                     'lx': img,
                     'lt': torch.zeros(1).to(self.device),