Enforce the zero terminal SNR to schedulers #397
Comments
This could be good to do but should be motivated with our own experiment to show there's a meaningful difference.
Starting to play with this, will let you know what I find out.
So I've found that starting from t=T helps for my application, but the zero terminal SNR hinders performance. However, my use case is a bit niche (segmentation). I'm planning to re-run the FID tutorial with the changes to see if there is any difference.
Hi @marksgraham, this implementation might help with our code too (huggingface/diffusers#3664)
Thanks walter. Interesting to see they aren't totally convinced by the updates. I like the idea of allowing the user to select the method for timestep discretisation (e.g. trailing, leading) as an argument to the scheduler. I think we could just make the noise schedules with snr=0 at t=T available as new options and keep the current ones as default.
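To illustrate the opt-in idea, the user-facing difference could look roughly like the sketch below. The `linear_beta_zero_snr` schedule name is hypothetical (no such schedule is registered today); only the first call matches the current API.

```python
from generative.networks.schedulers import DDPMScheduler  # import path as used in the GenerativeModels tutorials; adjust if needed

# current behaviour stays the default
scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta")

# a rescaled variant with zero terminal SNR would be opt-in under a new schedule name
# (hypothetical name, shown only to illustrate keeping existing defaults untouched)
scheduler_zsnr = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta_zero_snr")
```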
How would we want to implement this feature? In the `Scheduler` constructor we could add an argument for a rescale function, which we'd apply to the betas/alphas/alphas_cumprod once the noise schedule has been built:

```python
def __init__(
    self,
    num_train_timesteps: int = 1000,
    schedule: str = "linear_beta",
    rescale_func: Callable | None = None,
    **schedule_args,
) -> None:
    super().__init__()
    schedule_args["num_train_timesteps"] = num_train_timesteps
    noise_sched = NoiseSchedules[schedule](**schedule_args)

    # set betas, alphas, alphas_cumprod based off return value from noise function
    if isinstance(noise_sched, tuple):
        self.betas, self.alphas, self.alphas_cumprod = noise_sched
    else:
        self.betas = noise_sched
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    if rescale_func is not None:
        self.betas, self.alphas, self.alphas_cumprod = rescale_func(
            self.betas, self.alphas, self.alphas_cumprod
        )
    ...
```
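A caller would then pass any function with the matching signature. A minimal, hypothetical example of the plumbing (the clamping here is arbitrary and is not the zero-terminal-SNR correction; `rescale_func` is the proposed argument above, not an existing one):

```python
import torch


def my_rescale(betas: torch.Tensor, alphas: torch.Tensor, alphas_cumprod: torch.Tensor):
    # purely illustrative: cap the betas and recompute the derived tensors
    betas = betas.clamp(max=0.999)
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    return betas, alphas, alphas_cumprod


# assuming the proposed argument were forwarded by the concrete schedulers:
# scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="linear_beta", rescale_func=my_rescale)
```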
Allowing for a generic function to be supplied strikes me as a bit too general, given a user can now easily specify custom beta schedules. I think we could provide a …
As it stands, I think it makes sense to allow users to specify how they want to do the timestep spacing. I've found it helps in some of my applications. I tried to get some results for FID scores on MNIST, but the results are unclear. I propose we give users the option and let them decide what is right for their application. There is a suggestion of how to do it in the linked PR; what do you guys think? If it looks OK I'll implement it in the other schedulers. I haven't managed to get models with SNR=0 at t=T to train well at all, so I'm reluctant to implement it. I also note they weren't fully convinced of its utility in the huggingface discussion.
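A rough sketch of what a user-selectable spacing could look like, using the same option names as the diffusers PR linked above ("leading", "linspace", "trailing"); the helper name and exact rounding are illustrative, not existing code in this repo:

```python
import numpy as np
import torch


def make_timesteps(num_train_timesteps: int, num_inference_steps: int, spacing: str = "leading") -> torch.Tensor:
    """Illustrative helper showing three common ways of spacing inference timesteps."""
    if spacing == "leading":
        step_ratio = num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
    elif spacing == "linspace":
        timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)
    elif spacing == "trailing":
        # always includes t = T - 1, which is what the paper recommends for sampling
        step_ratio = num_train_timesteps / num_inference_steps
        timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
    else:
        raise ValueError(f"unknown spacing: {spacing}")
    return torch.from_numpy(np.ascontiguousarray(timesteps))


print(make_timesteps(1000, 5, "leading"))   # -> [800, 600, 400, 200, 0]
print(make_timesteps(1000, 5, "trailing"))  # -> [999, 799, 599, 399, 199]
```

Note that only "trailing" starts exactly at t = T - 1; "leading" never visits the last few timesteps, which is the mismatch discussed in the comment below.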
Hi, I found a related issue with the DDIM scheduler. When using a scheduler with a steep terminal SNR decay, like a cosine scheduler*, my results are super bad. However, if I change from the leading to the linspace timestep spacing in `set_timesteps`, the results improve:

```python
def set_timesteps(
    self, num_inference_steps: int, device: str | torch.device | None = None
) -> None:
    self.num_inference_steps = num_inference_steps
    step_ratio = self.num_train_timesteps // self.num_inference_steps
    # creates integer timesteps by multiplying by ratio
    # casting to int to avoid issues when num_inference_step is power of 3
    # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)  # Leading
    timesteps = (
        np.linspace(0, self.num_train_timesteps - 1, num_inference_steps)
        .round()[::-1]
        .copy()
        .astype(np.int64)
    )  # Linspace
    self.timesteps = torch.from_numpy(timesteps).to(device)
    self.timesteps += self.steps_offset
```

Intuitively, I guess it makes sense that skipping the early, steep steps in the SNR is an issue for a cosine scheduler, while the first part likely does not matter so much in the case of something like a linear schedule. Anyways, how about letting the user decide which scheme to use, similar to what was discussed above?

*Note: I am using a custom implementation for the cosine schedule:

```python
def betas_for_alpha_bar(
    num_diffusion_timesteps, alpha_bar, max_beta=0.999
):  # https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L45C1-L62C27
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)


@NoiseSchedules.add_def("cosine_poly", "Cosine schedule")
def _cosine_beta(num_train_timesteps: int, s: float = 8e-3, order: float = 2, *args):
    return betas_for_alpha_bar(
        num_train_timesteps,
        lambda t: np.cos((t + s) / (1 + s) * np.pi / 2) ** order,
    )
```
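To make that intuition concrete, here is a small standalone check (assuming the cosine formula above and a linear beta schedule with the typical 1e-4 / 2e-2 endpoints). With leading spacing and 50 inference steps, sampling starts at t=980, so it compares the SNR the model expects there with the SNR at t=999:

```python
import numpy as np
import torch


def cosine_betas(T: int, s: float = 8e-3, order: float = 2, max_beta: float = 0.999) -> torch.Tensor:
    # same construction as betas_for_alpha_bar / cosine_poly above
    alpha_bar = lambda t: np.cos((t + s) / (1 + s) * np.pi / 2) ** order
    return torch.tensor([min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), max_beta) for i in range(T)])


def snr(betas: torch.Tensor) -> torch.Tensor:
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    return alphas_cumprod / (1.0 - alphas_cumprod)


T = 1000
for name, betas in [("linear", torch.linspace(1e-4, 2e-2, T)), ("cosine", cosine_betas(T))]:
    snrs = snr(betas.double())
    # with leading spacing and 50 steps, sampling starts from pure noise but labels it t=980
    print(f"{name}: SNR(t=980)={float(snrs[980]):.1e}, SNR(t=999)={float(snrs[999]):.1e}")
```

The linear schedule barely changes over that tail, while the cosine schedule drops by several orders of magnitude, so starting at t=980 leaves a large mismatch between the pure-noise input and the SNR the model expects at that timestep.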
Hi, sorry for the slow response. We've had another report of bad results with the cosine scheduler here too. I actually started implementing this but closed the PR because I couldn't find any benefit; it does seem worth implementing, though. My closed PR is here. Do you have any interest in resurrecting it and doing a new PR? If not, I will do it, but it won't be a priority for a while as we will be focusing on integration with MONAI core for the moment.
No worries, and thanks for your comment! I want to look into it, though I have some deadlines coming up, so I fear I won't have time in the next couple of weeks.
Can this issue be closed, since there is a new issue #489 open?
I'd say do :)
According to "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), the implementations of most schedulers do not use t = T in the sampling process. We should include the corrections to enforce the zero terminal SNR.
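For reference, the schedule-side correction from the paper is short. The sketch below follows its Algorithm 1 (rescale sqrt(alphas_cumprod) so that the final value is exactly zero); the function name mirrors the diffusers helper and is not existing code in this repo:

```python
import torch


def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """Rescale a beta schedule so the terminal SNR is exactly zero (paper, Algorithm 1)."""
    alphas = 1.0 - betas
    alphas_bar = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # shift so the last value becomes 0, then scale so the first value is unchanged
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = (alphas_bar_sqrt - alphas_bar_sqrt_T) * alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # convert the rescaled cumulative products back to betas
    alphas_bar = alphas_bar_sqrt**2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas


betas = rescale_zero_terminal_snr(torch.linspace(1e-4, 2e-2, 1000))
print(torch.cumprod(1.0 - betas, dim=0)[-1])  # -> 0.0
```

Note that the last beta becomes 1, so epsilon prediction is undefined at t = T; the paper pairs this rescaling with v-prediction and with sampling that actually starts from t = T (trailing spacing).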