
[to #42322933] add dpm-solver for diffusion models

Add DPM-Solver support for diffusion models; it is roughly 2-6x faster than the DDIM scheduler.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10826722
master^2
lulu.lcq yingda.chen 3 years ago
parent commit ff171500bb
6 changed files with 1506 additions and 119 deletions
  1. modelscope/models/multi_modal/diffusion/diffusion.py (+58 -0)
  2. modelscope/models/multi_modal/diffusion/model.py (+149 -54)
  3. modelscope/models/multi_modal/dpm_solver_pytorch.py (+1075 -0)
  4. modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py (+58 -0)
  5. modelscope/models/multi_modal/multi_stage_diffusion/model.py (+156 -65)
  6. tests/pipelines/test_text_to_image_synthesis.py (+10 -0)
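
A minimal usage sketch, based on the test added in this commit: the 'solver' key in the pipeline input selects between the new DPM-Solver path (the default) and the existing DDIM path. The model id below is a placeholder.

from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Placeholder model id; substitute a real text-to-image model.
model = Model.from_pretrained('damo/your-text-to-image-model')
pipe = pipeline(task=Tasks.text_to_image_synthesis, model=model)

# 'solver' switches the sampler; 'dpm-solver' is roughly 2-6x faster
# than 'ddim' at comparable quality.
img = pipe({'text': 'a photo of a cat',
            'solver': 'dpm-solver'})[OutputKeys.OUTPUT_IMG]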

modelscope/models/multi_modal/diffusion/diffusion.py (+58 -0)

@@ -5,6 +5,9 @@ import math


import torch


from modelscope.models.multi_modal.dpm_solver_pytorch import (
DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion)

__all__ = ['GaussianDiffusion', 'beta_schedule']




@@ -259,6 +262,61 @@ class GaussianDiffusion(object):
x0 = x0.clamp(-clamp, clamp)
return mu, var, log_var, x0


@torch.no_grad()
def dpm_solver_sample_loop(self,
noise,
model,
skip_type,
order,
method,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
dpm_solver_timesteps=20,
t_start=None,
t_end=None,
lower_order_final=True,
denoise_to_zero=False,
solver_type='dpm_solver'):
r"""Sample using DPM-Solver-based method.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
Please check all the parameters in `dpm_solver.sample` before using.
"""
noise_schedule = NoiseScheduleVP(
schedule='discrete', betas=self.betas.float())
model_fn = model_wrapper_guided_diffusion(
model=model,
noise_schedule=noise_schedule,
var_type=self.var_type,
mean_type=self.mean_type,
model_kwargs=model_kwargs,
clamp=clamp,
percentile=percentile,
rescale_timesteps=self.rescale_timesteps,
num_timesteps=self.num_timesteps,
guide_scale=guide_scale,
condition_fn=condition_fn,
)
dpm_solver = DPM_Solver(
model_fn=model_fn,
noise_schedule=noise_schedule,
)
xt = dpm_solver.sample(
noise,
steps=dpm_solver_timesteps,
order=order,
skip_type=skip_type,
method=method,
solver_type=solver_type,
t_start=t_start,
t_end=t_end,
lower_order_final=lower_order_final,
denoise_to_zero=denoise_to_zero)
return xt

@torch.no_grad()
def ddim_sample(self,
xt,
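
For reference, a sketch of calling the new sampler directly; `diffusion` (a GaussianDiffusion instance) and `unet` (the denoising network) are placeholders, and the keyword values mirror the defaults used elsewhere in this commit.

import torch

# `diffusion` and `unet` are assumed to exist; both are placeholders.
x0 = diffusion.dpm_solver_sample_loop(
    noise=torch.randn(1, 3, 64, 64),
    model=unet,
    skip_type='logSNR',          # spacing of the solver time steps
    order=3,                     # solver order (1-3)
    method='singlestep',         # single-step DPM-Solver variant
    dpm_solver_timesteps=20)     # far fewer steps than DDIM's 250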


modelscope/models/multi_modal/diffusion/model.py (+149 -54)

@@ -197,60 +197,155 @@ class DiffusionForTextToImageSynthesis(Model):
attention_mask=attention_mask)
context = context[-1]


# generation
img = self.diffusion_generator.ddim_sample_loop(
noise=torch.randn(1, 3, 64, 64).to(self.device),
model=self.unet_generator,
model_kwargs=[{
'y': y,
'context': context,
'mask': attention_mask
}, {
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': attention_mask
}],
percentile=input.get('generator_percentile', 0.995),
guide_scale=input.get('generator_guide_scale', 5.0),
ddim_timesteps=input.get('generator_ddim_timesteps', 250),
eta=input.get('generator_ddim_eta', 0.0))

# upsampling (64->256)
if not input.get('debug', False):
img = F.interpolate(
img, scale_factor=4.0, mode='bilinear', align_corners=False)
img = self.diffusion_upsampler_256.ddim_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_256,
model_kwargs=[{
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': y,
'context': context,
'mask': attention_mask
}, {
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': torch.zeros_like(attention_mask)
}],
percentile=input.get('upsampler_256_percentile', 0.995),
guide_scale=input.get('upsampler_256_guide_scale', 5.0),
ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50),
eta=input.get('upsampler_256_ddim_eta', 0.0))

# upsampling (256->1024)
if not input.get('debug', False):
img = F.interpolate(
img, scale_factor=4.0, mode='bilinear', align_corners=False)
img = self.diffusion_upsampler_1024.ddim_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_1024,
model_kwargs={'concat': img},
percentile=input.get('upsampler_1024_percentile', 0.995),
ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
eta=input.get('upsampler_1024_ddim_eta', 0.0))
# choose a proper solver
solver = input.get('solver', 'dpm-solver')
if solver == 'dpm-solver':
# generation
img = self.diffusion_generator.dpm_solver_sample_loop(
noise=torch.randn(1, 3, 64, 64).to(self.device),
model=self.unet_generator,
model_kwargs=[{
'y': y,
'context': context,
'mask': attention_mask
}, {
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': attention_mask
}],
percentile=input.get('generator_percentile', 0.995),
guide_scale=input.get('generator_guide_scale', 5.0),
dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
order=3,
skip_type='logSNR',
method='singlestep',
t_start=0.9946)

# upsampling (64->256)
if not input.get('debug', False):
img = F.interpolate(
img,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
img = self.diffusion_upsampler_256.dpm_solver_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_256,
model_kwargs=[{
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': y,
'context': context,
'mask': attention_mask
}, {
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': torch.zeros_like(attention_mask)
}],
percentile=input.get('upsampler_256_percentile', 0.995),
guide_scale=input.get('upsampler_256_guide_scale', 5.0),
dpm_solver_timesteps=input.get('dpm_solver_timesteps', 20),
order=3,
skip_type='logSNR',
method='singlestep',
t_start=0.9946)

# upsampling (256->1024)
if not input.get('debug', False):
img = F.interpolate(
img,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
img = self.diffusion_upsampler_1024.dpm_solver_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_1024,
model_kwargs={'concat': img},
percentile=input.get('upsampler_1024_percentile', 0.995),
dpm_solver_timesteps=input.get('dpm_solver_timesteps', 10),
order=3,
skip_type='logSNR',
method='singlestep',
t_start=None)
elif solver == 'ddim':
# generation
img = self.diffusion_generator.ddim_sample_loop(
noise=torch.randn(1, 3, 64, 64).to(self.device),
model=self.unet_generator,
model_kwargs=[{
'y': y,
'context': context,
'mask': attention_mask
}, {
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': attention_mask
}],
percentile=input.get('generator_percentile', 0.995),
guide_scale=input.get('generator_guide_scale', 5.0),
ddim_timesteps=input.get('generator_ddim_timesteps', 250),
eta=input.get('generator_ddim_eta', 0.0))

# upsampling (64->256)
if not input.get('debug', False):
img = F.interpolate(
img,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
img = self.diffusion_upsampler_256.ddim_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_256,
model_kwargs=[{
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': y,
'context': context,
'mask': attention_mask
}, {
'lx': img,
'lt': torch.zeros(1).to(self.device),
'y': torch.zeros_like(y),
'context': torch.zeros_like(context),
'mask': torch.zeros_like(attention_mask)
}],
percentile=input.get('upsampler_256_percentile', 0.995),
guide_scale=input.get('upsampler_256_guide_scale', 5.0),
ddim_timesteps=input.get('upsampler_256_ddim_timesteps', 50),
eta=input.get('upsampler_256_ddim_eta', 0.0))

# upsampling (256->1024)
if not input.get('debug', False):
img = F.interpolate(
img,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
img = self.diffusion_upsampler_1024.ddim_sample_loop(
noise=torch.randn_like(img),
model=self.unet_upsampler_1024,
model_kwargs={'concat': img},
percentile=input.get('upsampler_1024_percentile', 0.995),
ddim_timesteps=input.get('upsampler_1024_ddim_timesteps', 20),
eta=input.get('upsampler_1024_ddim_eta', 0.0))
else:
raise ValueError(
'currently only supports "ddim" and "dpm-solver" solvers')


# output
img = img.clamp(-1, 1).add(1).mul(127.5).squeeze(0).permute(
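
Stepping back from the diff above: the solver branch in this file is driven entirely by keys on the input dict. An illustrative input (values match the defaults in the code; the dict itself is only an example):

# Illustrative input for the text-to-image forward pass.
inputs = {
    'text': 'a photo of a cat',
    'solver': 'dpm-solver',        # or 'ddim'
    'generator_guide_scale': 5.0,  # classifier-free guidance scale
    'dpm_solver_timesteps': 20,    # read by the generator and upsampler stages
}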


modelscope/models/multi_modal/dpm_solver_pytorch.py (+1075 -0)
File diff suppressed because it is too large
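
The suppressed file contains the solver itself. Reconstructed from the call sites in this diff (not the full API), the intended usage pattern is roughly:

# Sketch only: `betas`, `unet`, and `noise` are placeholders, and
# model_wrapper_guided_diffusion takes further kwargs (var_type, mean_type,
# guidance and clamping options) that dpm_solver_sample_loop forwards from
# the GaussianDiffusion object.
from modelscope.models.multi_modal.dpm_solver_pytorch import (
    DPM_Solver, NoiseScheduleVP, model_wrapper_guided_diffusion)

noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
model_fn = model_wrapper_guided_diffusion(
    model=unet, noise_schedule=noise_schedule)
dpm_solver = DPM_Solver(model_fn=model_fn, noise_schedule=noise_schedule)
x0 = dpm_solver.sample(noise, steps=20, order=3,
                       skip_type='logSNR', method='singlestep')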


modelscope/models/multi_modal/multi_stage_diffusion/gaussian_diffusion.py (+58 -0)

@@ -6,6 +6,9 @@ import math


import torch


from modelscope.models.multi_modal.dpm_solver_pytorch import (
DPM_Solver, NoiseScheduleVP, model_wrapper, model_wrapper_guided_diffusion)

__all__ = ['GaussianDiffusion', 'beta_schedule']




@@ -279,6 +282,61 @@ class GaussianDiffusion(object):
x0 = x0.clamp(-clamp, clamp)
return mu, var, log_var, x0


@torch.no_grad()
def dpm_solver_sample_loop(self,
noise,
model,
skip_type,
order,
method,
model_kwargs={},
clamp=None,
percentile=None,
condition_fn=None,
guide_scale=None,
dpm_solver_timesteps=20,
t_start=None,
t_end=None,
lower_order_final=True,
denoise_to_zero=False,
solver_type='dpm_solver'):
r"""Sample using DPM-Solver-based method.
- condition_fn: for classifier-based guidance (guided-diffusion).
- guide_scale: for classifier-free guidance (glide/dalle-2).
Please check all the parameters in `dpm_solver.sample` before using.
"""
noise_schedule = NoiseScheduleVP(
schedule='discrete', betas=self.betas.float())
model_fn = model_wrapper_guided_diffusion(
model=model,
noise_schedule=noise_schedule,
var_type=self.var_type,
mean_type=self.mean_type,
model_kwargs=model_kwargs,
clamp=clamp,
percentile=percentile,
rescale_timesteps=self.rescale_timesteps,
num_timesteps=self.num_timesteps,
guide_scale=guide_scale,
condition_fn=condition_fn,
)
dpm_solver = DPM_Solver(
model_fn=model_fn,
noise_schedule=noise_schedule,
)
xt = dpm_solver.sample(
noise,
steps=dpm_solver_timesteps,
order=order,
skip_type=skip_type,
method=method,
solver_type=solver_type,
t_start=t_start,
t_end=t_end,
lower_order_final=lower_order_final,
denoise_to_zero=denoise_to_zero)
return xt

@torch.no_grad()
def ddim_sample(self,
xt,


modelscope/models/multi_modal/multi_stage_diffusion/model.py (+156 -65)

@@ -95,7 +95,8 @@ class UnCLIP(nn.Module):
eta_prior=0.0,
eta_64=0.0,
eta_256=0.0,
eta_1024=0.0):
eta_1024=0.0,
solver='dpm-solver'):
device = next(self.parameters()).device


# check params
@@ -141,71 +142,160 @@ class UnCLIP(nn.Module):


# synthesis
with amp.autocast(enabled=True):
# prior
x0 = self.prior_diffusion.ddim_sample_loop(
noise=torch.randn_like(y),
model=self.prior,
model_kwargs=[{
'y': y
}, {
'y': zero_y
}],
guide_scale=guide_prior,
ddim_timesteps=timesteps_prior,
eta=eta_prior)
# choose a proper solver
if solver == 'dpm-solver':
# prior
x0 = self.prior_diffusion.dpm_solver_sample_loop(
noise=torch.randn_like(y),
model=self.prior,
model_kwargs=[{
'y': y
}, {
'y': zero_y
}],
guide_scale=guide_prior,
dpm_solver_timesteps=timesteps_prior,
order=3,
skip_type='logSNR',
method='singlestep',
t_start=0.9946)


# decoder
imgs64 = self.decoder_diffusion.ddim_sample_loop(
noise=torch.randn(batch_size, 3, 64, 64).to(device),
model=self.decoder,
model_kwargs=[{
'y': x0
}, {
'y': torch.zeros_like(x0)
}],
guide_scale=guide_64,
percentile=0.995,
ddim_timesteps=timesteps_64,
eta=eta_64).clamp_(-1, 1)
# decoder
imgs64 = self.decoder_diffusion.dpm_solver_sample_loop(
noise=torch.randn(batch_size, 3, 64, 64).to(device),
model=self.decoder,
model_kwargs=[{
'y': x0
}, {
'y': torch.zeros_like(x0)
}],
guide_scale=guide_64,
percentile=0.995,
dpm_solver_timesteps=timesteps_64,
order=3,
skip_type='logSNR',
method='singlestep',
t_start=0.9946).clamp_(-1, 1)


# upsampler256
imgs256 = F.interpolate(
imgs64, scale_factor=4.0, mode='bilinear', align_corners=False)
imgs256 = self.upsampler256_diffusion.ddim_sample_loop(
noise=torch.randn_like(imgs256),
model=self.upsampler256,
model_kwargs=[{
'y': y,
'concat': imgs256
}, {
'y': zero_y,
'concat': imgs256
}],
guide_scale=guide_256,
percentile=0.995,
ddim_timesteps=timesteps_256,
eta=eta_256).clamp_(-1, 1)
# upsampler256
imgs256 = F.interpolate(
imgs64,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
imgs256 = self.upsampler256_diffusion.dpm_solver_sample_loop(
noise=torch.randn_like(imgs256),
model=self.upsampler256,
model_kwargs=[{
'y': y,
'concat': imgs256
}, {
'y': zero_y,
'concat': imgs256
}],
guide_scale=guide_256,
percentile=0.995,
dpm_solver_timesteps=timesteps_256,
order=3,
skip_type='logSNR',
method='singlestep',
t_start=0.9946).clamp_(-1, 1)


# upsampler1024
imgs1024 = F.interpolate(
imgs256,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop(
noise=torch.randn_like(imgs1024),
model=self.upsampler1024,
model_kwargs=[{
'y': y,
'concat': imgs1024
}, {
'y': zero_y,
'concat': imgs1024
}],
guide_scale=guide_1024,
percentile=0.995,
ddim_timesteps=timesteps_1024,
eta=eta_1024).clamp_(-1, 1)
# upsampler1024
imgs1024 = F.interpolate(
imgs256,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
imgs1024 = self.upsampler1024_diffusion.dpm_solver_sample_loop(
noise=torch.randn_like(imgs1024),
model=self.upsampler1024,
model_kwargs=[{
'y': y,
'concat': imgs1024
}, {
'y': zero_y,
'concat': imgs1024
}],
guide_scale=guide_1024,
percentile=0.995,
dpm_solver_timesteps=timesteps_1024,
order=3,
skip_type='logSNR',
method='singlestep',
t_start=None).clamp_(-1, 1)
elif solver == 'ddim':
# prior
x0 = self.prior_diffusion.ddim_sample_loop(
noise=torch.randn_like(y),
model=self.prior,
model_kwargs=[{
'y': y
}, {
'y': zero_y
}],
guide_scale=guide_prior,
ddim_timesteps=timesteps_prior,
eta=eta_prior)

# decoder
imgs64 = self.decoder_diffusion.ddim_sample_loop(
noise=torch.randn(batch_size, 3, 64, 64).to(device),
model=self.decoder,
model_kwargs=[{
'y': x0
}, {
'y': torch.zeros_like(x0)
}],
guide_scale=guide_64,
percentile=0.995,
ddim_timesteps=timesteps_64,
eta=eta_64).clamp_(-1, 1)

# upsampler256
imgs256 = F.interpolate(
imgs64,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
imgs256 = self.upsampler256_diffusion.ddim_sample_loop(
noise=torch.randn_like(imgs256),
model=self.upsampler256,
model_kwargs=[{
'y': y,
'concat': imgs256
}, {
'y': zero_y,
'concat': imgs256
}],
guide_scale=guide_256,
percentile=0.995,
ddim_timesteps=timesteps_256,
eta=eta_256).clamp_(-1, 1)

# upsampler1024
imgs1024 = F.interpolate(
imgs256,
scale_factor=4.0,
mode='bilinear',
align_corners=False)
imgs1024 = self.upsampler1024_diffusion.ddim_sample_loop(
noise=torch.randn_like(imgs1024),
model=self.upsampler1024,
model_kwargs=[{
'y': y,
'concat': imgs1024
}, {
'y': zero_y,
'concat': imgs1024
}],
guide_scale=guide_1024,
percentile=0.995,
ddim_timesteps=timesteps_1024,
eta=eta_1024).clamp_(-1, 1)
else:
raise ValueError(
'currently only supports "ddim" and "dpm-solver" solvers')


# output ([B, C, H, W] within range [0, 1])
imgs1024 = imgs1024.add_(1).mul_(255 / 2.0).permute(0, 2, 3, 1).cpu()
@@ -245,7 +335,7 @@ class MultiStageDiffusionForTextToImageSynthesis(TorchModel):
if 'text' not in input:
raise ValueError('input should contain "text", but not found')


# ddim sampling
# sampling
imgs = self.model.synthesis(
text=input.get('text'),
tokenizer=input.get('tokenizer', 'clip'),
@@ -261,6 +351,7 @@ class MultiStageDiffusionForTextToImageSynthesis(TorchModel):
eta_prior=input.get('eta_prior', 0.0),
eta_64=input.get('eta_64', 0.0),
eta_256=input.get('eta_256', 0.0),
eta_1024=input.get('eta_1024', 0.0))
eta_1024=input.get('eta_1024', 0.0),
solver=input.get('solver', 'dpm-solver'))
imgs = [np.array(u)[..., ::-1] for u in imgs]
return imgs
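
The multi-stage model threads the same switch through its input dict; a minimal call sketch (the loaded `model` is a placeholder):

# `model` is a loaded MultiStageDiffusionForTextToImageSynthesis instance.
imgs = model({
    'text': 'a photo of a cat',
    'solver': 'dpm-solver',  # forwarded to UnCLIP.synthesis(..., solver=...)
})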

tests/pipelines/test_text_to_image_synthesis.py (+10 -0)

@@ -51,6 +51,16 @@ class TextToImageSynthesisTest(unittest.TestCase, DemoCompatibilityCheck):
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))


@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_model_from_modelhub_dpm_solver(self):
self.test_text.update({'solver': 'dpm-solver'})
model = Model.from_pretrained(self.model_id)
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=model)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))

@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()

