Combine the two classifier-free guidance model outputs into a single batch
crowsonkb committed Apr 5, 2022
1 parent f0c4e09 commit 66df437
Showing 2 changed files with 14 additions and 9 deletions.
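
For context: classifier-free guidance runs the denoiser twice per sampling step, once with the conditioning c and once with the unconditional conditioning, then combines the two noise predictions as e_t = e_t_uncond + scale * (e_t_cond - e_t_uncond). Before this commit each sampler made two separate apply_model calls; the change below concatenates the two inputs along the batch dimension so both predictions come from a single forward pass, which is generally faster on GPU. A minimal standalone sketch of the batched pattern (guided_eps and the bare model callable are hypothetical stand-ins, not part of the repository):

    import torch

    def guided_eps(model, x, t, cond, uncond, scale):
        # model(x, t, c) stands in for self.model.apply_model(x, t, c).
        if uncond is None or scale == 1.:
            return model(x, t, cond)
        # Duplicate x and t, and stack [uncond, cond] along the batch axis,
        # so one forward pass yields both noise predictions.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond, cond])
        e_t_uncond, e_t = model(x_in, t_in, c_in).chunk(2)
        # Guidance: push the prediction away from the unconditional
        # output by the guidance scale.
        return e_t_uncond + scale * (e_t - e_t_uncond)

The tradeoff is roughly double the activation memory per step in exchange for one batched forward pass instead of two.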
11 changes: 7 additions & 4 deletions ldm/models/diffusion/ddim.py
@@ -166,11 +166,14 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None):
         b, *_, device = *x.shape, x.device
-        e_t = self.model.apply_model(x, t, c)

-        if unconditional_guidance_scale != 1.:
-            assert unconditional_conditioning is not None
-            e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+            e_t = self.model.apply_model(x, t, c)
+        else:
+            x_in = torch.cat([x] * 2)
+            t_in = torch.cat([t] * 2)
+            c_in = torch.cat([unconditional_conditioning, c])
+            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

         if score_corrector is not None:
12 changes: 7 additions & 5 deletions ldm/models/diffusion/plms.py
@@ -176,11 +176,13 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
         b, *_, device = *x.shape, x.device

         def get_model_output(x, t):
-            e_t = self.model.apply_model(x, t, c)
-
-            if unconditional_guidance_scale != 1.:
-                assert unconditional_conditioning is not None
-                e_t_uncond = self.model.apply_model(x, t, unconditional_conditioning)
+            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+                e_t = self.model.apply_model(x, t, c)
+            else:
+                x_in = torch.cat([x] * 2)
+                t_in = torch.cat([t] * 2)
+                c_in = torch.cat([unconditional_conditioning, c])
+                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
             e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

             if score_corrector is not None:
