为diffusion模型加入dpm solver支持,相比ddim scheduler快2~6倍。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10826722
master^2
| @@ -5,6 +5,9 @@ import math | |||||
| import torch | import torch | ||||
| from modelscope.models.multi_modal.dpm_solver_pytorch import ( | |||||
| DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion) | |||||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | __all__ = ['GaussianDiffusion', 'beta_schedule'] | ||||
| @@ -259,6 +262,61 @@ class GaussianDiffusion(object): | |||||
| x0 = x0.clamp(-clamp, clamp) | x0 = x0.clamp(-clamp, clamp) | ||||
| return mu, var, log_var, x0 | return mu, var, log_var, x0 | ||||
@torch.no_grad()
def dpm_solver_sample_loop(self,
                           noise,
                           model,
                           skip_type,
                           order,
                           method,
                           model_kwargs=None,
                           clamp=None,
                           percentile=None,
                           condition_fn=None,
                           guide_scale=None,
                           dpm_solver_timesteps=20,
                           t_start=None,
                           t_end=None,
                           lower_order_final=True,
                           denoise_to_zero=False,
                           solver_type='dpm_solver'):
    r"""Sample using a DPM-Solver-based method.

    Wraps ``model`` with the guided-diffusion adapter and runs continuous
    DPM-Solver sampling, which typically needs far fewer network
    evaluations than DDIM for comparable quality.

    Args:
        noise: initial noise tensor ``x_T`` that sampling starts from.
        model: the diffusion network to be wrapped.
        skip_type, order, method, solver_type, t_start, t_end,
        lower_order_final, denoise_to_zero: forwarded verbatim to
            ``DPM_Solver.sample``; check its documentation before use.
        model_kwargs: extra keyword arguments for the model. ``None``
            (the default) means no extras. A ``None`` sentinel is used
            instead of a mutable ``{}`` default to avoid the shared
            mutable-default-argument pitfall.
        clamp, percentile: x0 post-processing options.
        condition_fn: classifier-based guidance (guided-diffusion).
        guide_scale: classifier-free guidance scale (GLIDE / DALL-E 2).
        dpm_solver_timesteps: number of solver steps (default 20).

    Returns:
        The tensor produced by ``DPM_Solver.sample``.
    """
    if model_kwargs is None:
        model_kwargs = {}
    # Discrete-beta VP noise schedule derived from this diffusion's betas.
    noise_schedule = NoiseScheduleVP(
        schedule='discrete', betas=self.betas.float())
    # Adapt the (possibly guided) discrete-time model to the continuous
    # interface the solver expects.
    model_fn = model_wrapper_guided_diffusion(
        model=model,
        noise_schedule=noise_schedule,
        var_type=self.var_type,
        mean_type=self.mean_type,
        model_kwargs=model_kwargs,
        clamp=clamp,
        percentile=percentile,
        rescale_timesteps=self.rescale_timesteps,
        num_timesteps=self.num_timesteps,
        guide_scale=guide_scale,
        condition_fn=condition_fn,
    )
    dpm_solver = DPM_Solver(
        model_fn=model_fn,
        noise_schedule=noise_schedule,
    )
    xt = dpm_solver.sample(
        noise,
        steps=dpm_solver_timesteps,
        order=order,
        skip_type=skip_type,
        method=method,
        solver_type=solver_type,
        t_start=t_start,
        t_end=t_end,
        lower_order_final=lower_order_final,
        denoise_to_zero=denoise_to_zero)
    return xt
| @torch.no_grad() | @torch.no_grad() | ||||
| def ddim_sample(self, | def ddim_sample(self, | ||||
| xt, | xt, | ||||
| @@ -197,60 +197,155 @@ class DiffusionForTextToImageSynthesis(Model): | |||||
| attention_mask=attention_mask) | attention_mask=attention_mask) | ||||
| context = context[-1] | context = context[-1] | ||||
| # generation | |||||
| img = self.diffusion_generator.ddim_sample_loop( | |||||
| noise=torch.randn(1, 3, 64, 64).to(self.device), | |||||
| model=self.unet_generator, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'context': context, | |||||
| 'mask': attention_mask | |||||
| }, { | |||||
| 'y': torch.zeros_like(y), | |||||
| 'context': torch.zeros_like(context), | |||||
| 'mask': attention_mask | |||||
| }], | |||||
| percentile=input.get('generator_percentile', 0.995), | |||||
| guide_scale=input.get('generator_guide_scale', 5.0), | |||||
| ddim_timesteps=input.get('generator_ddim_timesteps', 250), | |||||
| eta=input.get('generator_ddim_eta', 0.0)) | |||||
| # upsampling (64->256) | |||||
| if not input.get('debug', False): | |||||
| img = F.interpolate( | |||||
| img, scale_factor=4.0, mode='bilinear', align_corners=False) | |||||
| img = self.diffusion_upsampler_256.ddim_sample_loop( | |||||
| noise=torch.randn_like(img), | |||||
| model=self.unet_upsampler_256, | |||||
| model_kwargs=[{ | |||||
| 'lx': img, | |||||
| 'lt': torch.zeros(1).to(self.device), | |||||
| 'y': y, | |||||
| 'context': context, | |||||
| 'mask': attention_mask | |||||
| }, { | |||||
| 'lx': img, | |||||
| 'lt': torch.zeros(1).to(self.device), | |||||
| 'y': torch.zeros_like(y), | |||||
| 'context': torch.zeros_like(context), | |||||
| 'mask': torch.zeros_like(attention_mask) | |||||
| }], | |||||
| percentile=input.get('upsampler_256_percentile', 0.995), | |||||
| guide_scale=input.get('upsampler_256_guide_scale', 5.0), | |||||
| ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50), | |||||
| eta=input.get('upsampler_256_ddim_eta', 0.0)) | |||||
| # upsampling (256->1024) | |||||
| if not input.get('debug', False): | |||||
| img = F.interpolate( | |||||
| img, scale_factor=4.0, mode='bilinear', align_corners=False) | |||||
| img = self.diffusion_upsampler_1024.ddim_sample_loop( | |||||
| noise=torch.randn_like(img), | |||||
| model=self.unet_upsampler_1024, | |||||
| model_kwargs={'concat': img}, | |||||
| percentile=input.get('upsampler_1024_percentile', 0.995), | |||||
| ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20), | |||||
| eta=input.get('upsampler_1024_ddim_eta', 0.0)) | |||||
# Choose the sampling solver: 'dpm-solver' needs far fewer network
# evaluations than 'ddim' (roughly 2-6x faster) at similar quality.
solver = input.get('solver', 'dpm-solver')
if solver == 'dpm-solver':
    # generation (text embedding -> 64x64 image)
    img = self.diffusion_generator.dpm_solver_sample_loop(
        noise=torch.randn(1, 3, 64, 64).to(self.device),
        model=self.unet_generator,
        # Two kwargs dicts: conditional / unconditional branches for
        # classifier-free guidance.
        model_kwargs=[{
            'y': y,
            'context': context,
            'mask': attention_mask
        }, {
            'y': torch.zeros_like(y),
            'context': torch.zeros_like(context),
            'mask': attention_mask
        }],
        percentile=input.get('generator_percentile', 0.995),
        guide_scale=input.get('generator_guide_scale', 5.0),
        dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
        order=3,
        skip_type='logSNR',
        method='singlestep',
        t_start=0.9946)
    # upsampling (64 -> 256)
    if not input.get('debug', False):
        img = F.interpolate(
            img,
            scale_factor=4.0,
            mode='bilinear',
            align_corners=False)
        img = self.diffusion_upsampler_256.dpm_solver_sample_loop(
            noise=torch.randn_like(img),
            model=self.unet_upsampler_256,
            model_kwargs=[{
                'lx': img,
                'lt': torch.zeros(1).to(self.device),
                'y': y,
                'context': context,
                'mask': attention_mask
            }, {
                'lx': img,
                'lt': torch.zeros(1).to(self.device),
                'y': torch.zeros_like(y),
                'context': torch.zeros_like(context),
                'mask': torch.zeros_like(attention_mask)
            }],
            percentile=input.get('upsampler_256_percentile', 0.995),
            guide_scale=input.get('upsampler_256_guide_scale', 5.0),
            dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
            order=3,
            skip_type='logSNR',
            method='singlestep',
            t_start=0.9946)
    # upsampling (256 -> 1024)
    if not input.get('debug', False):
        img = F.interpolate(
            img,
            scale_factor=4.0,
            mode='bilinear',
            align_corners=False)
        # BUGFIX(review): this stage previously drove
        # self.unet_upsampler_256 with the 256-stage kwargs/option keys
        # (copy-paste). It must use the 1024 unet with the same interface
        # the ddim branch uses: model_kwargs={'concat': img} and the
        # 'upsampler_1024_*' option key.
        img = self.diffusion_upsampler_1024.dpm_solver_sample_loop(
            noise=torch.randn_like(img),
            model=self.unet_upsampler_1024,
            model_kwargs={'concat': img},
            percentile=input.get('upsampler_1024_percentile', 0.995),
            dpm_solver_timesteps=input.get('dpm_solver_timesteps', 10),
            order=3,
            skip_type='logSNR',
            method='singlestep',
            t_start=None)
elif solver == 'ddim':
    # generation
    img = self.diffusion_generator.ddim_sample_loop(
        noise=torch.randn(1, 3, 64, 64).to(self.device),
        model=self.unet_generator,
        model_kwargs=[{
            'y': y,
            'context': context,
            'mask': attention_mask
        }, {
            'y': torch.zeros_like(y),
            'context': torch.zeros_like(context),
            'mask': attention_mask
        }],
        percentile=input.get('generator_percentile', 0.995),
        guide_scale=input.get('generator_guide_scale', 5.0),
        ddim_timesteps=input.get('generator_ddim_timesteps', 250),
        eta=input.get('generator_ddim_eta', 0.0))
    # upsampling (64 -> 256)
    if not input.get('debug', False):
        img = F.interpolate(
            img,
            scale_factor=4.0,
            mode='bilinear',
            align_corners=False)
        img = self.diffusion_upsampler_256.ddim_sample_loop(
            noise=torch.randn_like(img),
            model=self.unet_upsampler_256,
            model_kwargs=[{
                'lx': img,
                'lt': torch.zeros(1).to(self.device),
                'y': y,
                'context': context,
                'mask': attention_mask
            }, {
                'lx': img,
                'lt': torch.zeros(1).to(self.device),
                'y': torch.zeros_like(y),
                'context': torch.zeros_like(context),
                'mask': torch.zeros_like(attention_mask)
            }],
            percentile=input.get('upsampler_256_percentile', 0.995),
            guide_scale=input.get('upsampler_256_guide_scale', 5.0),
            ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50),
            eta=input.get('upsampler_256_ddim_eta', 0.0))
    # upsampling (256 -> 1024)
    if not input.get('debug', False):
        img = F.interpolate(
            img,
            scale_factor=4.0,
            mode='bilinear',
            align_corners=False)
        img = self.diffusion_upsampler_1024.ddim_sample_loop(
            noise=torch.randn_like(img),
            model=self.unet_upsampler_1024,
            model_kwargs={'concat': img},
            percentile=input.get('upsampler_1024_percentile', 0.995),
            ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
            eta=input.get('upsampler_1024_ddim_eta', 0.0))
else:
    # BUGFIX(review): error message said "dpm-solve".
    raise ValueError(
        'currently only supports "ddim" and "dpm-solver" solvers')
| # output | # output | ||||
| img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute( | img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute( | ||||
| @@ -6,6 +6,9 @@ import math | |||||
| import torch | import torch | ||||
| from modelscope.models.multi_modal.dpm_solver_pytorch import ( | |||||
| DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion) | |||||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | __all__ = ['GaussianDiffusion', 'beta_schedule'] | ||||
| @@ -279,6 +282,61 @@ class GaussianDiffusion(object): | |||||
| x0 = x0.clamp(-clamp, clamp) | x0 = x0.clamp(-clamp, clamp) | ||||
| return mu, var, log_var, x0 | return mu, var, log_var, x0 | ||||
@torch.no_grad()
def dpm_solver_sample_loop(self,
                           noise,
                           model,
                           skip_type,
                           order,
                           method,
                           model_kwargs=None,
                           clamp=None,
                           percentile=None,
                           condition_fn=None,
                           guide_scale=None,
                           dpm_solver_timesteps=20,
                           t_start=None,
                           t_end=None,
                           lower_order_final=True,
                           denoise_to_zero=False,
                           solver_type='dpm_solver'):
    r"""Sample using a DPM-Solver-based method.

    Wraps ``model`` with the guided-diffusion adapter and runs continuous
    DPM-Solver sampling, which typically needs far fewer network
    evaluations than DDIM for comparable quality.

    Args:
        noise: initial noise tensor ``x_T`` that sampling starts from.
        model: the diffusion network to be wrapped.
        skip_type, order, method, solver_type, t_start, t_end,
        lower_order_final, denoise_to_zero: forwarded verbatim to
            ``DPM_Solver.sample``; check its documentation before use.
        model_kwargs: extra keyword arguments for the model. ``None``
            (the default) means no extras. A ``None`` sentinel is used
            instead of a mutable ``{}`` default to avoid the shared
            mutable-default-argument pitfall.
        clamp, percentile: x0 post-processing options.
        condition_fn: classifier-based guidance (guided-diffusion).
        guide_scale: classifier-free guidance scale (GLIDE / DALL-E 2).
        dpm_solver_timesteps: number of solver steps (default 20).

    Returns:
        The tensor produced by ``DPM_Solver.sample``.
    """
    if model_kwargs is None:
        model_kwargs = {}
    # Discrete-beta VP noise schedule derived from this diffusion's betas.
    noise_schedule = NoiseScheduleVP(
        schedule='discrete', betas=self.betas.float())
    # Adapt the (possibly guided) discrete-time model to the continuous
    # interface the solver expects.
    model_fn = model_wrapper_guided_diffusion(
        model=model,
        noise_schedule=noise_schedule,
        var_type=self.var_type,
        mean_type=self.mean_type,
        model_kwargs=model_kwargs,
        clamp=clamp,
        percentile=percentile,
        rescale_timesteps=self.rescale_timesteps,
        num_timesteps=self.num_timesteps,
        guide_scale=guide_scale,
        condition_fn=condition_fn,
    )
    dpm_solver = DPM_Solver(
        model_fn=model_fn,
        noise_schedule=noise_schedule,
    )
    xt = dpm_solver.sample(
        noise,
        steps=dpm_solver_timesteps,
        order=order,
        skip_type=skip_type,
        method=method,
        solver_type=solver_type,
        t_start=t_start,
        t_end=t_end,
        lower_order_final=lower_order_final,
        denoise_to_zero=denoise_to_zero)
    return xt
| @torch.no_grad() | @torch.no_grad() | ||||
| def ddim_sample(self, | def ddim_sample(self, | ||||
| xt, | xt, | ||||
| @@ -95,7 +95,8 @@ class UnCLIP(nn.Module): | |||||
| eta_prior=0.0, | eta_prior=0.0, | ||||
| eta_64=0.0, | eta_64=0.0, | ||||
| eta_256=0.0, | eta_256=0.0, | ||||
| eta_1024=0.0): | |||||
| eta_1024=0.0, | |||||
| solver='dpm-solver'): | |||||
| device = next(self.parameters()).device | device = next(self.parameters()).device | ||||
| # check params | # check params | ||||
| @@ -141,71 +142,160 @@ class UnCLIP(nn.Module): | |||||
| # synthesis | # synthesis | ||||
| with amp.autocast(enabled=True): | with amp.autocast(enabled=True): | ||||
| # prior | |||||
| x0 = self.prior_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(y), | |||||
| model=self.prior, | |||||
| model_kwargs=[{ | |||||
| 'y': y | |||||
| }, { | |||||
| 'y': zero_y | |||||
| }], | |||||
| guide_scale=guide_prior, | |||||
| ddim_timesteps=timesteps_prior, | |||||
| eta=eta_prior) | |||||
| # choose a proper solver | |||||
| if solver == 'dpm-solver': | |||||
| # prior | |||||
| x0 = self.prior_diffusion.dpm_solver_sample_loop( | |||||
| noise=torch.randn_like(y), | |||||
| model=self.prior, | |||||
| model_kwargs=[{ | |||||
| 'y': y | |||||
| }, { | |||||
| 'y': zero_y | |||||
| }], | |||||
| guide_scale=guide_prior, | |||||
| dpm_solver_timesteps=timesteps_prior, | |||||
| order=3, | |||||
| skip_type='logSNR', | |||||
| method='singlestep', | |||||
| t_start=0.9946) | |||||
| # decoder | |||||
| imgs64 = self.decoder_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||||
| model=self.decoder, | |||||
| model_kwargs=[{ | |||||
| 'y': x0 | |||||
| }, { | |||||
| 'y': torch.zeros_like(x0) | |||||
| }], | |||||
| guide_scale=guide_64, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_64, | |||||
| eta=eta_64).clamp_(-1, 1) | |||||
| # decoder | |||||
| imgs64 = self.decoder_diffusion.dpm_solver_sample_loop( | |||||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||||
| model=self.decoder, | |||||
| model_kwargs=[{ | |||||
| 'y': x0 | |||||
| }, { | |||||
| 'y': torch.zeros_like(x0) | |||||
| }], | |||||
| guide_scale=guide_64, | |||||
| percentile=0.995, | |||||
| dpm_solver_timesteps=timesteps_64, | |||||
| order=3, | |||||
| skip_type='logSNR', | |||||
| method='singlestep', | |||||
| t_start=0.9946).clamp_(-1, 1) | |||||
| # upsampler256 | |||||
| imgs256 = F.interpolate( | |||||
| imgs64, scale_factor=4.0, mode='bilinear', align_corners=False) | |||||
| imgs256 = self.upsampler256_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(imgs256), | |||||
| model=self.upsampler256, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs256 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs256 | |||||
| }], | |||||
| guide_scale=guide_256, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_256, | |||||
| eta=eta_256).clamp_(-1, 1) | |||||
| # upsampler256 | |||||
| imgs256 = F.interpolate( | |||||
| imgs64, | |||||
| scale_factor=4.0, | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| imgs256 = self.upsampler256_diffusion.dpm_solver_sample_loop( | |||||
| noise=torch.randn_like(imgs256), | |||||
| model=self.upsampler256, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs256 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs256 | |||||
| }], | |||||
| guide_scale=guide_256, | |||||
| percentile=0.995, | |||||
| dpm_solver_timesteps=timesteps_256, | |||||
| order=3, | |||||
| skip_type='logSNR', | |||||
| method='singlestep', | |||||
| t_start=0.9946).clamp_(-1, 1) | |||||
| # upsampler1024 | |||||
| imgs1024 = F.interpolate( | |||||
| imgs256, | |||||
| scale_factor=4.0, | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(imgs1024), | |||||
| model=self.upsampler1024, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs1024 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs1024 | |||||
| }], | |||||
| guide_scale=guide_1024, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_1024, | |||||
| eta=eta_1024).clamp_(-1, 1) | |||||
| # upsampler1024 | |||||
| imgs1024 = F.interpolate( | |||||
| imgs256, | |||||
| scale_factor=4.0, | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| imgs1024 = self.upsampler1024_diffusion.dpm_solver_sample_loop( | |||||
| noise=torch.randn_like(imgs1024), | |||||
| model=self.upsampler1024, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs1024 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs1024 | |||||
| }], | |||||
| guide_scale=guide_1024, | |||||
| percentile=0.995, | |||||
| dpm_solver_timesteps=timesteps_1024, | |||||
| order=3, | |||||
| skip_type='logSNR', | |||||
| method='singlestep', | |||||
| t_start=None).clamp_(-1, 1) | |||||
| elif solver == 'ddim': | |||||
| # prior | |||||
| x0 = self.prior_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(y), | |||||
| model=self.prior, | |||||
| model_kwargs=[{ | |||||
| 'y': y | |||||
| }, { | |||||
| 'y': zero_y | |||||
| }], | |||||
| guide_scale=guide_prior, | |||||
| ddim_timesteps=timesteps_prior, | |||||
| eta=eta_prior) | |||||
| # decoder | |||||
| imgs64 = self.decoder_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||||
| model=self.decoder, | |||||
| model_kwargs=[{ | |||||
| 'y': x0 | |||||
| }, { | |||||
| 'y': torch.zeros_like(x0) | |||||
| }], | |||||
| guide_scale=guide_64, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_64, | |||||
| eta=eta_64).clamp_(-1, 1) | |||||
| # upsampler256 | |||||
| imgs256 = F.interpolate( | |||||
| imgs64, | |||||
| scale_factor=4.0, | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| imgs256 = self.upsampler256_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(imgs256), | |||||
| model=self.upsampler256, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs256 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs256 | |||||
| }], | |||||
| guide_scale=guide_256, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_256, | |||||
| eta=eta_256).clamp_(-1, 1) | |||||
| # upsampler1024 | |||||
| imgs1024 = F.interpolate( | |||||
| imgs256, | |||||
| scale_factor=4.0, | |||||
| mode='bilinear', | |||||
| align_corners=False) | |||||
| imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop( | |||||
| noise=torch.randn_like(imgs1024), | |||||
| model=self.upsampler1024, | |||||
| model_kwargs=[{ | |||||
| 'y': y, | |||||
| 'concat': imgs1024 | |||||
| }, { | |||||
| 'y': zero_y, | |||||
| 'concat': imgs1024 | |||||
| }], | |||||
| guide_scale=guide_1024, | |||||
| percentile=0.995, | |||||
| ddim_timesteps=timesteps_1024, | |||||
| eta=eta_1024).clamp_(-1, 1) | |||||
| else: | |||||
| raise ValueError( | |||||
| 'currently only supports "ddim" and "dpm-solver" solvers') | |||||
| # output ([B, C, H, W] within range [0, 1]) | # output ([B, C, H, W] within range [0, 1]) | ||||
| imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu() | imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu() | ||||
| @@ -245,7 +335,7 @@ class MultiStageDiffusionForTextToImageSynthesis(TorchModel): | |||||
| if 'text' not in input: | if 'text' not in input: | ||||
| raise ValueError('input should contain "text", but not found') | raise ValueError('input should contain "text", but not found') | ||||
| # ddim sampling | |||||
| # sampling | |||||
| imgs = self.model.synthesis( | imgs = self.model.synthesis( | ||||
| text=input.get('text'), | text=input.get('text'), | ||||
| tokenizer=input.get('tokenizer', 'clip'), | tokenizer=input.get('tokenizer', 'clip'), | ||||
| @@ -261,6 +351,7 @@ class MultiStageDiffusionForTextToImageSynthesis(TorchModel): | |||||
| eta_prior=input.get('eta_prior', 0.0), | eta_prior=input.get('eta_prior', 0.0), | ||||
| eta_64=input.get('eta_64', 0.0), | eta_64=input.get('eta_64', 0.0), | ||||
| eta_256=input.get('eta_256', 0.0), | eta_256=input.get('eta_256', 0.0), | ||||
| eta_1024=input.get('eta_1024', 0.0)) | |||||
| eta_1024=input.get('eta_1024', 0.0), | |||||
| solver=input.get('solver', 'dpm-solver')) | |||||
| imgs = [np.array(u)[..., ::-1] for u in imgs] | imgs = [np.array(u)[..., ::-1] for u in imgs] | ||||
| return imgs | return imgs | ||||
| @@ -51,6 +51,16 @@ class TextToImageSynthesisTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| self.test_text)[OutputKeys.OUTPUT_IMG] | self.test_text)[OutputKeys.OUTPUT_IMG] | ||||
| print(np.sum(np.abs(img))) | print(np.sum(np.abs(img))) | ||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub_dpm_solver(self):
    """Run text-to-image synthesis end-to-end with the DPM-Solver sampler."""
    # BUGFIX(review): 'test_text' is an instance/class attribute (used as
    # self.test_text below and in the sibling test); the bare name would
    # raise NameError.
    self.test_text.update({'solver': 'dpm-solver'})
    model = Model.from_pretrained(self.model_id)
    pipe_line_text_to_image_synthesis = pipeline(
        task=Tasks.text_to_image_synthesis, model=model)
    img = pipe_line_text_to_image_synthesis(
        self.test_text)[OutputKeys.OUTPUT_IMG]
    print(np.sum(np.abs(img)))
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | @unittest.skip('demo compatibility test is only enabled on a needed-basis') | ||||
| def test_demo_compatibility(self): | def test_demo_compatibility(self): | ||||
| self.compatibility_check() | self.compatibility_check() | ||||