为diffusion模型加入dpm solver支持,相比ddim scheduler快2~6倍。
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10826722
master^2
| @@ -5,6 +5,9 @@ import math | |||
| import torch | |||
| from modelscope.models.multi_modal.dpm_solver_pytorch import ( | |||
| DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion) | |||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | |||
| @@ -259,6 +262,61 @@ class GaussianDiffusion(object): | |||
| x0 = x0.clamp(-clamp, clamp) | |||
| return mu, var, log_var, x0 | |||
| @torch.no_grad() | |||
| def dpm_solver_sample_loop(self, | |||
| noise, | |||
| model, | |||
| skip_type, | |||
| order, | |||
| method, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None, | |||
| dpm_solver_timesteps=20, | |||
| t_start=None, | |||
| t_end=None, | |||
| lower_order_final=True, | |||
| denoise_to_zero=False, | |||
| solver_type='dpm_solver'): | |||
| r"""Sample using DPM-Solver-based method. | |||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||
| Please check all the parameters in `dpm_solver.sample` before using. | |||
| """ | |||
| noise_schedule = NoiseScheduleVP( | |||
| schedule='discrete', betas=self.betas.float()) | |||
| model_fn = model_wrapper_guided_diffusion( | |||
| model=model, | |||
| noise_schedule=noise_schedule, | |||
| var_type=self.var_type, | |||
| mean_type=self.mean_type, | |||
| model_kwargs=model_kwargs, | |||
| clamp=clamp, | |||
| percentile=percentile, | |||
| rescale_timesteps=self.rescale_timesteps, | |||
| num_timesteps=self.num_timesteps, | |||
| guide_scale=guide_scale, | |||
| condition_fn=condition_fn, | |||
| ) | |||
| dpm_solver = DPM_Solver( | |||
| model_fn=model_fn, | |||
| noise_schedule=noise_schedule, | |||
| ) | |||
| xt = dpm_solver.sample( | |||
| noise, | |||
| steps=dpm_solver_timesteps, | |||
| order=order, | |||
| skip_type=skip_type, | |||
| method=method, | |||
| solver_type=solver_type, | |||
| t_start=t_start, | |||
| t_end=t_end, | |||
| lower_order_final=lower_order_final, | |||
| denoise_to_zero=denoise_to_zero) | |||
| return xt | |||
| @torch.no_grad() | |||
| def ddim_sample(self, | |||
| xt, | |||
| @@ -197,60 +197,155 @@ class DiffusionForTextToImageSynthesis(Model): | |||
| attention_mask=attention_mask) | |||
| context = context[-1] | |||
| # generation | |||
| img = self.diffusion_generator.ddim_sample_loop( | |||
| noise=torch.randn(1, 3, 64, 64).to(self.device), | |||
| model=self.unet_generator, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'context': context, | |||
| 'mask': attention_mask | |||
| }, { | |||
| 'y': torch.zeros_like(y), | |||
| 'context': torch.zeros_like(context), | |||
| 'mask': attention_mask | |||
| }], | |||
| percentile=input.get('generator_percentile', 0.995), | |||
| guide_scale=input.get('generator_guide_scale', 5.0), | |||
| ddim_timesteps=input.get('generator_ddim_timesteps', 250), | |||
| eta=input.get('generator_ddim_eta', 0.0)) | |||
| # upsampling (64->256) | |||
| if not input.get('debug', False): | |||
| img = F.interpolate( | |||
| img, scale_factor=4.0, mode='bilinear', align_corners=False) | |||
| img = self.diffusion_upsampler_256.ddim_sample_loop( | |||
| noise=torch.randn_like(img), | |||
| model=self.unet_upsampler_256, | |||
| model_kwargs=[{ | |||
| 'lx': img, | |||
| 'lt': torch.zeros(1).to(self.device), | |||
| 'y': y, | |||
| 'context': context, | |||
| 'mask': attention_mask | |||
| }, { | |||
| 'lx': img, | |||
| 'lt': torch.zeros(1).to(self.device), | |||
| 'y': torch.zeros_like(y), | |||
| 'context': torch.zeros_like(context), | |||
| 'mask': torch.zeros_like(attention_mask) | |||
| }], | |||
| percentile=input.get('upsampler_256_percentile', 0.995), | |||
| guide_scale=input.get('upsampler_256_guide_scale', 5.0), | |||
| ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50), | |||
| eta=input.get('upsampler_256_ddim_eta', 0.0)) | |||
| # upsampling (256->1024) | |||
| if not input.get('debug', False): | |||
| img = F.interpolate( | |||
| img, scale_factor=4.0, mode='bilinear', align_corners=False) | |||
| img = self.diffusion_upsampler_1024.ddim_sample_loop( | |||
| noise=torch.randn_like(img), | |||
| model=self.unet_upsampler_1024, | |||
| model_kwargs={'concat': img}, | |||
| percentile=input.get('upsampler_1024_percentile', 0.995), | |||
| ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20), | |||
| eta=input.get('upsampler_1024_ddim_eta', 0.0)) | |||
| # choose a proper solver | |||
| solver = input.get('solver', 'dpm-solver') | |||
| if solver == 'dpm-solver': | |||
| # generation | |||
| img = self.diffusion_generator.dpm_solver_sample_loop( | |||
| noise=torch.randn(1, 3, 64, 64).to(self.device), | |||
| model=self.unet_generator, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'context': context, | |||
| 'mask': attention_mask | |||
| }, { | |||
| 'y': torch.zeros_like(y), | |||
| 'context': torch.zeros_like(context), | |||
| 'mask': attention_mask | |||
| }], | |||
| percentile=input.get('generator_percentile', 0.995), | |||
| guide_scale=input.get('generator_guide_scale', 5.0), | |||
| dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20), | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=0.9946) | |||
| # upsampling (64->256) | |||
| if not input.get('debug', False): | |||
| img = F.interpolate( | |||
| img, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| img = self.diffusion_upsampler_256.dpm_solver_sample_loop( | |||
| noise=torch.randn_like(img), | |||
| model=self.unet_upsampler_256, | |||
| model_kwargs=[{ | |||
| 'lx': img, | |||
| 'lt': torch.zeros(1).to(self.device), | |||
| 'y': y, | |||
| 'context': context, | |||
| 'mask': attention_mask | |||
| }, { | |||
| 'lx': img, | |||
| 'lt': torch.zeros(1).to(self.device), | |||
| 'y': torch.zeros_like(y), | |||
| 'context': torch.zeros_like(context), | |||
| 'mask': torch.zeros_like(attention_mask) | |||
| }], | |||
| percentile=input.get('upsampler_256_percentile', 0.995), | |||
| guide_scale=input.get('upsampler_256_guide_scale', 5.0), | |||
| dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20), | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=0.9946) | |||
| # upsampling (256->1024) | |||
| if not input.get('debug', False): | |||
| img = F.interpolate( | |||
| img, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| img = self.diffusion_upsampler_1024.dpm_solver_sample_loop( | |||
| noise=torch.randn_like(img), | |||
| model=self.unet_upsampler_1024, | |||
| model_kwargs={'concat': img}, | |||
| percentile=input.get('upsampler_1024_percentile', 0.995), | |||
| dpm_solver_timesteps=input.get('dpm_solver_timesteps', 10), | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=None) | |||
| elif solver == 'ddim': | |||
| # generation | |||
| img = self.diffusion_generator.ddim_sample_loop( | |||
| noise=torch.randn(1, 3, 64, 64).to(self.device), | |||
| model=self.unet_generator, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'context': context, | |||
| 'mask': attention_mask | |||
| }, { | |||
| 'y': torch.zeros_like(y), | |||
| 'context': torch.zeros_like(context), | |||
| 'mask': attention_mask | |||
| }], | |||
| percentile=input.get('generator_percentile', 0.995), | |||
| guide_scale=input.get('generator_guide_scale', 5.0), | |||
| ddim_timesteps=input.get('generator_ddim_timesteps', 250), | |||
| eta=input.get('generator_ddim_eta', 0.0)) | |||
| # upsampling (64->256) | |||
| if not input.get('debug', False): | |||
| img = F.interpolate( | |||
| img, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| img = self.diffusion_upsampler_256.ddim_sample_loop( | |||
| noise=torch.randn_like(img), | |||
| model=self.unet_upsampler_256, | |||
| model_kwargs=[{ | |||
| 'lx': img, | |||
| 'lt': torch.zeros(1).to(self.device), | |||
| 'y': y, | |||
| 'context': context, | |||
| 'mask': attention_mask | |||
| }, { | |||
| 'lx': img, | |||
| 'lt': torch.zeros(1).to(self.device), | |||
| 'y': torch.zeros_like(y), | |||
| 'context': torch.zeros_like(context), | |||
| 'mask': torch.zeros_like(attention_mask) | |||
| }], | |||
| percentile=input.get('upsampler_256_percentile', 0.995), | |||
| guide_scale=input.get('upsampler_256_guide_scale', 5.0), | |||
| ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50), | |||
| eta=input.get('upsampler_256_ddim_eta', 0.0)) | |||
| # upsampling (256->1024) | |||
| if not input.get('debug', False): | |||
| img = F.interpolate( | |||
| img, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| img = self.diffusion_upsampler_1024.ddim_sample_loop( | |||
| noise=torch.randn_like(img), | |||
| model=self.unet_upsampler_1024, | |||
| model_kwargs={'concat': img}, | |||
| percentile=input.get('upsampler_1024_percentile', 0.995), | |||
| ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20), | |||
| eta=input.get('upsampler_1024_ddim_eta', 0.0)) | |||
| else: | |||
| raise ValueError( | |||
| 'currently only supports "ddim" and "dpm-solver" solvers') | |||
| # output | |||
| img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute( | |||
| @@ -6,6 +6,9 @@ import math | |||
| import torch | |||
| from modelscope.models.multi_modal.dpm_solver_pytorch import ( | |||
| DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion) | |||
| __all__ = ['GaussianDiffusion', 'beta_schedule'] | |||
| @@ -279,6 +282,61 @@ class GaussianDiffusion(object): | |||
| x0 = x0.clamp(-clamp, clamp) | |||
| return mu, var, log_var, x0 | |||
| @torch.no_grad() | |||
| def dpm_solver_sample_loop(self, | |||
| noise, | |||
| model, | |||
| skip_type, | |||
| order, | |||
| method, | |||
| model_kwargs={}, | |||
| clamp=None, | |||
| percentile=None, | |||
| condition_fn=None, | |||
| guide_scale=None, | |||
| dpm_solver_timesteps=20, | |||
| t_start=None, | |||
| t_end=None, | |||
| lower_order_final=True, | |||
| denoise_to_zero=False, | |||
| solver_type='dpm_solver'): | |||
| r"""Sample using DPM-Solver-based method. | |||
| - condition_fn: for classifier-based guidance (guided-diffusion). | |||
| - guide_scale: for classifier-free guidance (glide/dalle-2). | |||
| Please check all the parameters in `dpm_solver.sample` before using. | |||
| """ | |||
| noise_schedule = NoiseScheduleVP( | |||
| schedule='discrete', betas=self.betas.float()) | |||
| model_fn = model_wrapper_guided_diffusion( | |||
| model=model, | |||
| noise_schedule=noise_schedule, | |||
| var_type=self.var_type, | |||
| mean_type=self.mean_type, | |||
| model_kwargs=model_kwargs, | |||
| clamp=clamp, | |||
| percentile=percentile, | |||
| rescale_timesteps=self.rescale_timesteps, | |||
| num_timesteps=self.num_timesteps, | |||
| guide_scale=guide_scale, | |||
| condition_fn=condition_fn, | |||
| ) | |||
| dpm_solver = DPM_Solver( | |||
| model_fn=model_fn, | |||
| noise_schedule=noise_schedule, | |||
| ) | |||
| xt = dpm_solver.sample( | |||
| noise, | |||
| steps=dpm_solver_timesteps, | |||
| order=order, | |||
| skip_type=skip_type, | |||
| method=method, | |||
| solver_type=solver_type, | |||
| t_start=t_start, | |||
| t_end=t_end, | |||
| lower_order_final=lower_order_final, | |||
| denoise_to_zero=denoise_to_zero) | |||
| return xt | |||
| @torch.no_grad() | |||
| def ddim_sample(self, | |||
| xt, | |||
| @@ -95,7 +95,8 @@ class UnCLIP(nn.Module): | |||
| eta_prior=0.0, | |||
| eta_64=0.0, | |||
| eta_256=0.0, | |||
| eta_1024=0.0): | |||
| eta_1024=0.0, | |||
| solver='dpm-solver'): | |||
| device = next(self.parameters()).device | |||
| # check params | |||
| @@ -141,71 +142,160 @@ class UnCLIP(nn.Module): | |||
| # synthesis | |||
| with amp.autocast(enabled=True): | |||
| # prior | |||
| x0 = self.prior_diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(y), | |||
| model=self.prior, | |||
| model_kwargs=[{ | |||
| 'y': y | |||
| }, { | |||
| 'y': zero_y | |||
| }], | |||
| guide_scale=guide_prior, | |||
| ddim_timesteps=timesteps_prior, | |||
| eta=eta_prior) | |||
| # choose a proper solver | |||
| if solver == 'dpm-solver': | |||
| # prior | |||
| x0 = self.prior_diffusion.dpm_solver_sample_loop( | |||
| noise=torch.randn_like(y), | |||
| model=self.prior, | |||
| model_kwargs=[{ | |||
| 'y': y | |||
| }, { | |||
| 'y': zero_y | |||
| }], | |||
| guide_scale=guide_prior, | |||
| dpm_solver_timesteps=timesteps_prior, | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=0.9946) | |||
| # decoder | |||
| imgs64 = self.decoder_diffusion.ddim_sample_loop( | |||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||
| model=self.decoder, | |||
| model_kwargs=[{ | |||
| 'y': x0 | |||
| }, { | |||
| 'y': torch.zeros_like(x0) | |||
| }], | |||
| guide_scale=guide_64, | |||
| percentile=0.995, | |||
| ddim_timesteps=timesteps_64, | |||
| eta=eta_64).clamp_(-1, 1) | |||
| # decoder | |||
| imgs64 = self.decoder_diffusion.dpm_solver_sample_loop( | |||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||
| model=self.decoder, | |||
| model_kwargs=[{ | |||
| 'y': x0 | |||
| }, { | |||
| 'y': torch.zeros_like(x0) | |||
| }], | |||
| guide_scale=guide_64, | |||
| percentile=0.995, | |||
| dpm_solver_timesteps=timesteps_64, | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=0.9946).clamp_(-1, 1) | |||
| # upsampler256 | |||
| imgs256 = F.interpolate( | |||
| imgs64, scale_factor=4.0, mode='bilinear', align_corners=False) | |||
| imgs256 = self.upsampler256_diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(imgs256), | |||
| model=self.upsampler256, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': imgs256 | |||
| }, { | |||
| 'y': zero_y, | |||
| 'concat': imgs256 | |||
| }], | |||
| guide_scale=guide_256, | |||
| percentile=0.995, | |||
| ddim_timesteps=timesteps_256, | |||
| eta=eta_256).clamp_(-1, 1) | |||
| # upsampler256 | |||
| imgs256 = F.interpolate( | |||
| imgs64, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| imgs256 = self.upsampler256_diffusion.dpm_solver_sample_loop( | |||
| noise=torch.randn_like(imgs256), | |||
| model=self.upsampler256, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': imgs256 | |||
| }, { | |||
| 'y': zero_y, | |||
| 'concat': imgs256 | |||
| }], | |||
| guide_scale=guide_256, | |||
| percentile=0.995, | |||
| dpm_solver_timesteps=timesteps_256, | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=0.9946).clamp_(-1, 1) | |||
| # upsampler1024 | |||
| imgs1024 = F.interpolate( | |||
| imgs256, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(imgs1024), | |||
| model=self.upsampler1024, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': imgs1024 | |||
| }, { | |||
| 'y': zero_y, | |||
| 'concat': imgs1024 | |||
| }], | |||
| guide_scale=guide_1024, | |||
| percentile=0.995, | |||
| ddim_timesteps=timesteps_1024, | |||
| eta=eta_1024).clamp_(-1, 1) | |||
| # upsampler1024 | |||
| imgs1024 = F.interpolate( | |||
| imgs256, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| imgs1024 = self.upsampler1024_diffusion.dpm_solver_sample_loop( | |||
| noise=torch.randn_like(imgs1024), | |||
| model=self.upsampler1024, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': imgs1024 | |||
| }, { | |||
| 'y': zero_y, | |||
| 'concat': imgs1024 | |||
| }], | |||
| guide_scale=guide_1024, | |||
| percentile=0.995, | |||
| dpm_solver_timesteps=timesteps_1024, | |||
| order=3, | |||
| skip_type='logSNR', | |||
| method='singlestep', | |||
| t_start=None).clamp_(-1, 1) | |||
| elif solver == 'ddim': | |||
| # prior | |||
| x0 = self.prior_diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(y), | |||
| model=self.prior, | |||
| model_kwargs=[{ | |||
| 'y': y | |||
| }, { | |||
| 'y': zero_y | |||
| }], | |||
| guide_scale=guide_prior, | |||
| ddim_timesteps=timesteps_prior, | |||
| eta=eta_prior) | |||
| # decoder | |||
| imgs64 = self.decoder_diffusion.ddim_sample_loop( | |||
| noise=torch.randn(batch_size, 3, 64, 64).to(device), | |||
| model=self.decoder, | |||
| model_kwargs=[{ | |||
| 'y': x0 | |||
| }, { | |||
| 'y': torch.zeros_like(x0) | |||
| }], | |||
| guide_scale=guide_64, | |||
| percentile=0.995, | |||
| ddim_timesteps=timesteps_64, | |||
| eta=eta_64).clamp_(-1, 1) | |||
| # upsampler256 | |||
| imgs256 = F.interpolate( | |||
| imgs64, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| imgs256 = self.upsampler256_diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(imgs256), | |||
| model=self.upsampler256, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': imgs256 | |||
| }, { | |||
| 'y': zero_y, | |||
| 'concat': imgs256 | |||
| }], | |||
| guide_scale=guide_256, | |||
| percentile=0.995, | |||
| ddim_timesteps=timesteps_256, | |||
| eta=eta_256).clamp_(-1, 1) | |||
| # upsampler1024 | |||
| imgs1024 = F.interpolate( | |||
| imgs256, | |||
| scale_factor=4.0, | |||
| mode='bilinear', | |||
| align_corners=False) | |||
| imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop( | |||
| noise=torch.randn_like(imgs1024), | |||
| model=self.upsampler1024, | |||
| model_kwargs=[{ | |||
| 'y': y, | |||
| 'concat': imgs1024 | |||
| }, { | |||
| 'y': zero_y, | |||
| 'concat': imgs1024 | |||
| }], | |||
| guide_scale=guide_1024, | |||
| percentile=0.995, | |||
| ddim_timesteps=timesteps_1024, | |||
| eta=eta_1024).clamp_(-1, 1) | |||
| else: | |||
| raise ValueError( | |||
| 'currently only supports "ddim" and "dpm-solver" solvers') | |||
| # output ([B, C, H, W] within range [0, 1]) | |||
| imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu() | |||
| @@ -245,7 +335,7 @@ class MultiStageDiffusionForTextToImageSynthesis(TorchModel): | |||
| if 'text' not in input: | |||
| raise ValueError('input should contain "text", but not found') | |||
| # ddim sampling | |||
| # sampling | |||
| imgs = self.model.synthesis( | |||
| text=input.get('text'), | |||
| tokenizer=input.get('tokenizer', 'clip'), | |||
| @@ -261,6 +351,7 @@ class MultiStageDiffusionForTextToImageSynthesis(TorchModel): | |||
| eta_prior=input.get('eta_prior', 0.0), | |||
| eta_64=input.get('eta_64', 0.0), | |||
| eta_256=input.get('eta_256', 0.0), | |||
| eta_1024=input.get('eta_1024', 0.0)) | |||
| eta_1024=input.get('eta_1024', 0.0), | |||
| solver=input.get('solver', 'dpm-solver')) | |||
| imgs = [np.array(u)[..., ::-1] for u in imgs] | |||
| return imgs | |||
| @@ -51,6 +51,16 @@ class TextToImageSynthesisTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| self.test_text)[OutputKeys.OUTPUT_IMG] | |||
| print(np.sum(np.abs(img))) | |||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
| def test_run_with_model_from_modelhub_dpm_solver(self): | |||
| self.test_text.update({'solver': 'dpm-solver'}) | |||
| model = Model.from_pretrained(self.model_id) | |||
| pipe_line_text_to_image_synthesis = pipeline( | |||
| task=Tasks.text_to_image_synthesis, model=model) | |||
| img = pipe_line_text_to_image_synthesis( | |||
| self.test_text)[OutputKeys.OUTPUT_IMG] | |||
| print(np.sum(np.abs(img))) | |||
| @unittest.skip('demo compatibility test is only enabled on a needed-basis') | |||
| def test_demo_compatibility(self): | |||
| self.compatibility_check() | |||