|
|
|
@@ -215,8 +215,9 @@ class ImagenForTextToImageSynthesis(Model): |
|
|
|
eta=input.get('generator_ddim_eta', 0.0)) |
|
|
|
|
|
|
|
# upsampling (64->256) |
|
|
|
img = F.interpolate( |
|
|
|
img, scale_factor=4.0, mode='bilinear', align_corners=False) |
|
|
|
if not input.get('debug', False): |
|
|
|
img = F.interpolate( |
|
|
|
img, scale_factor=4.0, mode='bilinear', align_corners=False) |
|
|
|
img = self.diffusion_imagen_upsampler_256.ddim_sample_loop( |
|
|
|
noise=torch.randn_like(img), |
|
|
|
model=self.unet_imagen_upsampler_256, |
|
|
|
@@ -233,14 +234,15 @@ class ImagenForTextToImageSynthesis(Model): |
|
|
|
'context': torch.zeros_like(context), |
|
|
|
'mask': torch.zeros_like(attention_mask) |
|
|
|
}], |
|
|
|
percentile=input.get('generator_percentile', 0.995), |
|
|
|
guide_scale=input.get('generator_guide_scale', 5.0), |
|
|
|
ddim_timesteps=input.get('generator_ddim_timesteps', 50), |
|
|
|
eta=input.get('generator_ddim_eta', 0.0)) |
|
|
|
percentile=input.get('upsampler_256_percentile', 0.995), |
|
|
|
guide_scale=input.get('upsampler_256_guide_scale', 5.0), |
|
|
|
ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50), |
|
|
|
eta=input.get('upsampler_256_ddim_eta', 0.0)) |
|
|
|
|
|
|
|
# upsampling (256->1024) |
|
|
|
img = F.interpolate( |
|
|
|
img, scale_factor=4.0, mode='bilinear', align_corners=False) |
|
|
|
if not input.get('debug', False): |
|
|
|
img = F.interpolate( |
|
|
|
img, scale_factor=4.0, mode='bilinear', align_corners=False) |
|
|
|
img = self.diffusion_upsampler_1024.ddim_sample_loop( |
|
|
|
noise=torch.randn_like(img), |
|
|
|
model=self.unet_upsampler_1024, |
|
|
|
|