@@ -544,7 +544,7 @@ def __init__(
544544
545545 assert scale <= 1 , 'scale must be less than or equal to 1'
546546 self .scale = scale
547- self .normalize_img_variance = normalize_img_variance if scale < 1 else identity
547+ self .maybe_normalize_img_variance = normalize_img_variance if scale < 1 else identity
548548
549549 # gamma schedules
550550
@@ -596,8 +596,8 @@ def ddpm_sample(self, shape, time_difference = None):
596596
597597 # get predicted x0
598598
599- img = self .normalize_img_variance (img )
600- model_output , last_latents = self .model (img , noise_cond , x_start , last_latents , return_latents = True )
599+ maybe_normalized_img = self .maybe_normalize_img_variance (img )
600+ model_output , last_latents = self .model (maybe_normalized_img , noise_cond , x_start , last_latents , return_latents = True )
601601
602602 # get log(snr)
603603
@@ -675,8 +675,8 @@ def ddim_sample(self, shape, time_difference = None):
675675
676676 # predict x0
677677
678- img = self .normalize_img_variance (img )
679- model_output , last_latents = self .model (img , times , x_start , last_latents , return_latents = True )
678+ maybe_normalized_img = self .maybe_normalize_img_variance (img )
679+ model_output , last_latents = self .model (maybe_normalized_img , times , x_start , last_latents , return_latents = True )
680680
681681 # calculate x0 and noise
682682
@@ -732,7 +732,7 @@ def forward(self, img, *args, **kwargs):
732732
733733 noised_img = alpha * img + sigma * noise
734734
735- noised_img = self .normalize_img_variance (noised_img )
735+ noised_img = self .maybe_normalize_img_variance (noised_img )
736736
737737 # in the paper, they had to use a really high probability of latent self conditioning, up to 90% of the time
738738 # slight drawback
0 commit comments