Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Allow users to enforce zero terminal SNR in schedulers #404

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions generative/networks/schedulers/pndm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,20 @@ class PNDMPredictionType(StrEnum):
EPSILON = "epsilon"
V_PREDICTION = "v_prediction"

class PNDMTimestepSpacing(StrEnum):
"""
Set of valid inference timestep spacing names for the PNDM scheduler's `timestep_spacing` argument.

See Table 2. of "Common Diffusion Noise Schedules and Sample Steps are Flawed" https://arxiv.org/abs/2305.08891

leading: first step is always included.
linspace: first and last step are always included.
trailing: last step is always included.
"""

LEADING = "leading"
LINSPACE = "linspace"
TRAILING = "trailing"

class PNDMScheduler(Scheduler):
"""
Expand All @@ -73,6 +87,7 @@ class PNDMScheduler(Scheduler):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
timestep_spacing: member of PNDMTimestepSpacing. Controls which timesteps are included during inference.
schedule_args: arguments to pass to the schedule function
"""

Expand All @@ -84,15 +99,19 @@ def __init__(
set_alpha_to_one: bool = False,
prediction_type: str = PNDMPredictionType.EPSILON,
steps_offset: int = 0,
timestep_spacing: str = "leading",
**schedule_args,
) -> None:
super().__init__(num_train_timesteps, schedule, **schedule_args)

if prediction_type not in PNDMPredictionType.__members__.values():
raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")

self.prediction_type = prediction_type

if timestep_spacing not in PNDMTimestepSpacing.__members__.values():
raise ValueError("Argument `timestep_spacing` must be a member of PNDMTimestepSpacing")
self.timestep_spacing = timestep_spacing

self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

# standard deviation of the initial noise distribution
Expand Down Expand Up @@ -132,10 +151,16 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N

self.num_inference_steps = num_inference_steps
step_ratio = self.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
self._timesteps += self.steps_offset
if self.timestep_spacing == PNDMTimestepSpacing.LEADING:
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
self._timesteps += self.steps_offset
elif self.timestep_spacing == PNDMTimestepSpacing.LINSPACE:
self._timesteps = np.linspace(0, self.num_train_timesteps-1, self.num_inference_steps, dtype=np.int64)
elif self.timestep_spacing == PNDMTimestepSpacing.TRAILING:
self._timesteps = np.round(np.flip(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64))
self._timesteps -= 1

if self.skip_prk_steps:
# for some models like stable diffusion the prk steps can/should be skipped to
Expand Down