diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index f5e4775900de..846da6c76d59 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -789,6 +789,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index bc6133c8b2d1..bc985beae69d 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -705,6 +705,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index dda2f207b90a..ca6b5165fefb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -871,6 +871,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py index 9d2b3ca8abaf..da2f4ba9b6e9 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py @@ -566,6 +566,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py index 4daa1c07f0c6..449b6d88b9de 100644 --- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py @@ -536,6 +536,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 4146a35fb909..509b5ab34bde 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -634,6 +634,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 565544a0fef4..fda56088b916 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -906,6 +906,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index e431fee7bdb0..440a972ff8e0 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -467,6 +467,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index f78cd383b83a..a9b04b493c7e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -659,6 +659,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 5d77341511a3..111a70aa5c09 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -859,6 +859,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py index d0d132555e69..82e91e3565ea 100644 --- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py @@ -754,6 +754,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index c781e490caae..342a81b81a2e 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -554,6 +554,8 @@ def get_timesteps(self, num_inference_steps, strength, device): t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py index e7e0dcbdc31e..9b672a74fc26 100644 --- a/src/diffusers/schedulers/scheduling_consistency_models.py +++ b/src/diffusers/schedulers/scheduling_consistency_models.py @@ -98,15 +98,9 @@ def __init__( self.custom_timesteps = False self.is_scale_input_called = False self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - return indices.item() - @property def step_index(self): """ @@ -114,6 +108,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -231,6 +243,7 @@ def set_timesteps( self.timesteps = torch.from_numpy(timesteps).to(device=device) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Modified _convert_to_karras implementation that takes in ramp as argument @@ -280,23 +293,29 @@ def get_scalings_for_boundary_condition(self, sigma): c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 return c_skip, c_out - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() - self._step_index = step_index.item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -412,7 +431,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index e8bd5f8f68d4..a0831d80f71b 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -187,6 +187,7 @@ def __init__( self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -196,6 +197,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -255,6 +274,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # add an index counter for schedulers that allow duplicated timesteps self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample @@ -620,11 +640,12 @@ def ind_fn(t, b, c, d): else: raise NotImplementedError("only support log-rho multistep deis now") - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 @@ -637,7 +658,20 @@ def _init_step_index(self, timestep): else: step_index = index_candidates[0].item() - self._step_index = step_index + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -736,16 +770,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) + # begin_index is None when the scheduler is used for training + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d70d4eec9b3e..bfb0d943ee2c 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -227,6 +227,7 @@ def __init__( self.model_outputs = [None] * solver_order self.lower_order_nums = 0 self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -236,6 +237,23 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -311,6 +329,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # add an index counter for schedulers that allow duplicated timesteps self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample @@ -792,11 +811,11 @@ def multistep_dpm_solver_third_order_update( ) return x_t - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 @@ -809,7 +828,19 @@ def _init_step_index(self, timestep): else: step_index = index_candidates[0].item() - self._step_index = step_index + return step_index + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -920,16 +951,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) + # begin_index is None when the scheduler is used for training + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py index 03fc3677d07f..089cfc0d988f 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -767,7 +767,6 @@ def multistep_dpm_solver_third_order_update( ) return x_t - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index def _init_step_index(self, timestep): if isinstance(timestep, torch.Tensor): timestep = timestep.to(self.timesteps.device) @@ -879,7 +878,6 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch """ return sample - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index 20c294f95bd6..c51cd3f440a3 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -13,7 +13,6 @@ # limitations under the License. import math -from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -198,9 +197,10 @@ def __init__( self.noise_sampler = None self.noise_sampler_seed = noise_sampler_seed self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -211,31 +211,18 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(self._index_counter) == 0: - pos = 1 if len(indices) > 1 else 0 - else: - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - pos = self._index_counter[timestep_int] + pos = 1 if len(indices) > 1 else 0 return indices[pos].item() # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) else: - step_index = index_candidates[0] - - self._step_index = step_index.item() + self._step_index = self._begin_index @property def init_noise_sigma(self): @@ -252,6 +239,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, @@ -348,13 +353,10 @@ def set_timesteps( self.mid_point_sigma = None self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.noise_sampler = None - # for exp beta schedules, such as the one for `pipeline_shap_e.py` - # we need an index counter - self._index_counter = defaultdict(int) - def _second_order_timesteps(self, sigmas, log_sigmas): def sigma_fn(_t): return np.exp(-_t) @@ -444,10 +446,6 @@ def step( if self.step_index is None: self._init_step_index(timestep) - # advance index counter by 1 - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - self._index_counter[timestep_int] += 1 - # Create a noise sampler if it hasn't been created yet if self.noise_sampler is None: min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() @@ -527,7 +525,7 @@ def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: return SchedulerOutput(prev_sample=prev_sample) - # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -544,7 +542,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index f664374a4238..e22085da74f5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -210,6 +210,7 @@ def __init__( self.sample = None self.order_list = self.get_order_list(num_train_timesteps) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication def get_order_list(self, num_inference_steps: int) -> List[int]: @@ -253,6 +254,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -315,6 +334,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # add an index counter for schedulers that allow duplicated timesteps self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample @@ -813,11 +833,12 @@ def singlestep_dpm_solver_update( else: raise ValueError(f"Order must be 1, 2, 3, got {order}") - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 @@ -830,7 +851,20 @@ def _init_step_index(self, timestep): else: step_index = index_candidates[0].item() - self._step_index = step_index + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -925,16 +959,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) + # begin_index is None when the scheduler is used for training + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index acad67847237..35fb22c9fdab 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -216,6 +216,7 @@ def __init__( self.is_scale_input_called = False self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -233,6 +234,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -300,25 +319,32 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.timesteps = torch.from_numpy(timesteps).to(device=device) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() - self._step_index = step_index.item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -440,7 +466,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 6ed28f410aea..c5e858e545be 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -237,6 +237,7 @@ def __init__( self.use_karras_sigmas = use_karras_sigmas self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -255,6 +256,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -342,6 +361,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication def _sigma_to_t(self, sigma, log_sigmas): @@ -393,22 +413,27 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho return sigmas - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 - self._step_index = step_index.item() + return indices[pos].item() + + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -538,7 +563,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index a1ea18dcf168..b1877bae4727 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -13,7 +13,6 @@ # limitations under the License. import math -from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -148,8 +147,10 @@ def __init__( self.use_karras_sigmas = use_karras_sigmas self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps @@ -160,11 +161,7 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(self._index_counter) == 0: - pos = 1 if len(indices) > 1 else 0 - else: - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - pos = self._index_counter[timestep_int] + pos = 1 if len(indices) > 1 else 0 return indices[pos].item() @@ -183,6 +180,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, @@ -270,13 +285,9 @@ def set_timesteps( self.dt = None self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - # (YiYi Notes: keep this for now since we are keeping add_noise function which use index_for_timestep) - # for exp beta schedules, such as the one for `pipeline_shap_e.py` - # we need an index counter - self._index_counter = defaultdict(int) - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): # get log sigma @@ -333,21 +344,12 @@ def state_in_first_order(self): # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - index_candidates = (self.timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) else: - step_index = index_candidates[0] - - self._step_index = step_index.item() + self._step_index = self._begin_index def step( self, @@ -378,11 +380,6 @@ def step( if self.step_index is None: self._init_step_index(timestep) - # (YiYi notes: keep this for now since we are keeping the add_noise method) - # advance index counter by 1 - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - self._index_counter[timestep_int] += 1 - if self.state_in_first_order: sigma = self.sigmas[self.step_index] sigma_next = self.sigmas[self.step_index + 1] @@ -453,6 +450,7 @@ def step( return SchedulerOutput(prev_sample=prev_sample) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -469,7 +467,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index aeebd029a441..4025bad1a327 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -56,6 +56,7 @@ def __init__( # running values self.ets = [] self._step_index = None + self._begin_index = None @property def step_index(self): @@ -64,6 +65,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -90,24 +109,31 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.ets = [] self._step_index = None + self._begin_index = None - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 - self._step_index = step_index.item() + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 4a1cdb561cea..5c1934c1b077 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -13,7 +13,6 @@ # limitations under the License. import math -from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -140,27 +139,9 @@ def __init__( # set all values self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(self._index_counter) == 0: - pos = 1 if len(indices) > 1 else 0 - else: - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - pos = self._index_counter[timestep_int] - - return indices[pos].item() - @property def init_noise_sigma(self): # standard deviation of the initial noise distribution @@ -176,6 +157,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, @@ -295,11 +294,8 @@ def set_timesteps( self.sample = None - # for exp beta schedules, such as the one for `pipeline_shap_e.py` - # we need an index counter - self._index_counter = defaultdict(int) - self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t @@ -356,23 +352,29 @@ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) def state_in_first_order(self): return self.sample is None - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 - self._step_index = step_index.item() + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -406,10 +408,6 @@ def step( if self.step_index is None: self._init_step_index(timestep) - # advance index counter by 1 - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - self._index_counter[timestep_int] += 1 - if self.state_in_first_order: sigma = self.sigmas[self.step_index] sigma_interpol = self.sigmas_interpol[self.step_index] @@ -478,7 +476,7 @@ def step( return SchedulerOutput(prev_sample=prev_sample) - # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -495,7 +493,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 57062c0d3586..7c800e4e68b2 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -13,7 +13,6 @@ # limitations under the License. import math -from collections import defaultdict from typing import List, Optional, Tuple, Union import numpy as np @@ -140,27 +139,9 @@ def __init__( self.set_timesteps(num_train_timesteps, None, num_train_timesteps) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication - # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(self._index_counter) == 0: - pos = 1 if len(indices) > 1 else 0 - else: - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - pos = self._index_counter[timestep_int] - - return indices[pos].item() - @property def init_noise_sigma(self): # standard deviation of the initial noise distribution @@ -176,6 +157,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, @@ -280,34 +279,37 @@ def set_timesteps( self.sample = None - # for exp beta schedules, such as the one for `pipeline_shap_e.py` - # we need an index counter - self._index_counter = defaultdict(int) - self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property def state_in_first_order(self): return self.sample is None - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 - self._step_index = step_index.item() + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): @@ -388,10 +390,6 @@ def step( if self.step_index is None: self._init_step_index(timestep) - # advance index counter by 1 - timestep_int = timestep.cpu().item() if torch.is_tensor(timestep) else timestep - self._index_counter[timestep_int] += 1 - if self.state_in_first_order: sigma = self.sigmas[self.step_index] sigma_interpol = self.sigmas_interpol[self.step_index + 1] @@ -453,7 +451,7 @@ def step( return SchedulerOutput(prev_sample=prev_sample) - # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise def add_noise( self, original_samples: torch.FloatTensor, @@ -470,7 +468,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_lcm.py b/src/diffusers/schedulers/scheduling_lcm.py index a54f78423d73..1156c2634e31 100644 --- a/src/diffusers/schedulers/scheduling_lcm.py +++ b/src/diffusers/schedulers/scheduling_lcm.py @@ -250,29 +250,54 @@ def __init__( self.custom_timesteps = False self._step_index = None + self._begin_index = None - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() - self._step_index = step_index.item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index @property def step_index(self): return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: """ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the @@ -462,6 +487,7 @@ def set_timesteps( self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.long) self._step_index = None + self._begin_index = None def get_scalings_for_boundary_condition_discrete(self, timestep): self.sigma_data = 0.5 # Default: 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index f5f52b06bd43..02f78014d1f7 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -168,6 +168,7 @@ def __init__( self.is_scale_input_called = False self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -185,6 +186,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def scale_model_input( self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] ) -> torch.FloatTensor: @@ -280,27 +299,34 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.sigmas = torch.from_numpy(sigmas).to(device=device) self.timesteps = torch.from_numpy(timesteps).to(device=device) self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication self.derivatives = [] - # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + indices = (schedule_timesteps == timestep).nonzero() # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - if len(index_candidates) > 1: - step_index = index_candidates[1] - else: - step_index = index_candidates[0] + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() - self._step_index = step_index.item() + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t def _sigma_to_t(self, sigma, log_sigmas): @@ -434,7 +460,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape): diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py index 13e3c76cf5b4..6a07cd082a47 100644 --- a/src/diffusers/schedulers/scheduling_sasolver.py +++ b/src/diffusers/schedulers/scheduling_sasolver.py @@ -212,6 +212,7 @@ def __init__( self.lower_order_nums = 0 self.last_sample = None self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -221,6 +222,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -283,6 +302,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc # add an index counter for schedulers that allow duplicated timesteps self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample @@ -925,11 +945,12 @@ def stochastic_adams_moulton_update( x_t = x_t.to(x.dtype) return x_t - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 @@ -942,7 +963,20 @@ def _init_step_index(self, timestep): else: step_index = index_candidates[0].item() - self._step_index = step_index + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py index 1223213c69f3..e556093ee91b 100644 --- a/src/diffusers/schedulers/scheduling_unipc_multistep.py +++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py @@ -198,6 +198,7 @@ def __init__( self.solver_p = solver_p self.last_sample = None self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication @property @@ -207,6 +208,24 @@ def step_index(self): """ return self._step_index + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -269,6 +288,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic # add an index counter for schedulers that allow duplicated timesteps self._step_index = None + self._begin_index = None self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample @@ -698,11 +718,12 @@ def multistep_uni_c_bh_update( x_t = x_t.to(x.dtype) return x_t - def _init_step_index(self, timestep): - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps - index_candidates = (self.timesteps == timestep).nonzero() + index_candidates = (schedule_timesteps == timestep).nonzero() if len(index_candidates) == 0: step_index = len(self.timesteps) - 1 @@ -715,7 +736,20 @@ def _init_step_index(self, timestep): else: step_index = index_candidates[0].item() - self._step_index = step_index + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index def step( self, @@ -830,16 +864,11 @@ def add_noise( schedule_timesteps = self.timesteps.to(original_samples.device) timesteps = timesteps.to(original_samples.device) - step_indices = [] - for timestep in timesteps: - index_candidates = (schedule_timesteps == timestep).nonzero() - if len(index_candidates) == 0: - step_index = len(schedule_timesteps) - 1 - elif len(index_candidates) > 1: - step_index = index_candidates[1].item() - else: - step_index = index_candidates[0].item() - step_indices.append(step_index) + # begin_index is None when the scheduler is used for training + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + else: + step_indices = [self.begin_index] * timesteps.shape[0] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < len(original_samples.shape):