
Enforce the zero terminal SNR to schedulers #397

Closed
Warvito opened this issue May 25, 2023 · 13 comments · May be fixed by #404

Comments

@Warvito
Collaborator

Warvito commented May 25, 2023

According to "Common Diffusion Noise Schedules and Sample Steps are Flawed" (https://arxiv.org/pdf/2305.08891.pdf), the implementations of most schedulers do not use t = T in the sampling process. We should include the corrections needed to enforce a zero terminal SNR.
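For reference, a minimal sketch of the paper's Algorithm 1, which rescales an existing beta schedule so that the cumulative alpha (and hence the SNR) is exactly zero at the final timestep; the function name is just for illustration:

import torch

def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # Sketch of Algorithm 1 from https://arxiv.org/abs/2305.08891
    alphas_bar_sqrt = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Remember the original first and last values of sqrt(alpha_bar).
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift the last value to zero, then rescale so the first value is unchanged.
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert the rescaled alpha_bar values back into betas.
    alphas_bar = alphas_bar_sqrt**2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas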

@ericspod
Member

This could be good to do, but it should be motivated by our own experiments showing there's a meaningful difference.

@marksgraham
Collaborator

Starting to play with this, will let you know what I find out

@marksgraham
Collaborator

So I've found that starting from t=T helps for my application, but the zero terminal SNR hinders performance. However, my use case is a bit niche (segmentation). I'm planning to re-run the FID tutorial with the changes to see if there is any difference.

@Warvito
Collaborator Author

Warvito commented Jun 6, 2023

Hi @marksgraham, this implementation might help with our code too (huggingface/diffusers#3664).

@marksgraham
Collaborator

Thanks Walter. Interesting to see they aren't totally convinced by the updates.

I like the idea of allowing the user to select the method for timestep discretisation (e.g. trailing, leading) as an argument to the scheduler.

I think we could just make the noise schedules with snr=0 at t=T available as new options and keep the current ones as default.

@ericspod
Member

ericspod commented Jun 7, 2023

How would we want to implement this feature? In the Scheduler constructor we could add an argument for a rescale function which we'd apply to self.betas, self.alphas, self.alphas_cumprod:

def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", rescale_func: Callable | None = None, **schedule_args) -> None:
    super().__init__()
    schedule_args["num_train_timesteps"] = num_train_timesteps
    noise_sched = NoiseSchedules[schedule](**schedule_args)

    # set betas, alphas, alphas_cumprod based off the return value from the noise schedule function
    if isinstance(noise_sched, tuple):
        self.betas, self.alphas, self.alphas_cumprod = noise_sched
    else:
        self.betas = noise_sched
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

    if rescale_func is not None:
        self.betas, self.alphas, self.alphas_cumprod = rescale_func(self.betas, self.alphas, self.alphas_cumprod)
    ...
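A hypothetical call site for the proposed argument, reusing the enforce_zero_terminal_snr sketch from earlier in the thread (the rescale_func argument does not exist yet, and the scheduler class used here is only for illustration):

def rescale_to_zero_snr(betas, alphas, alphas_cumprod):
    # Recompute all three tensors from the rescaled betas.
    new_betas = enforce_zero_terminal_snr(betas)
    new_alphas = 1.0 - new_betas
    return new_betas, new_alphas, torch.cumprod(new_alphas, dim=0)

scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    schedule="linear_beta",
    rescale_func=rescale_to_zero_snr,
)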

@marksgraham
Collaborator

Allowing a generic function to be supplied strikes me as a bit too general, given a user can now easily specify custom beta schedules. I think we could either provide a rescale_to_snr0 flag, similar to what you've proposed, or just add a new set of beta schedule options which enforce snr=0 at t=T to complement the existing ones (e.g. linear_beta_snr0, scaled_linear_beta_snr0).
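For the second option, a rough sketch of how one zero-terminal-SNR variant could be registered alongside the existing schedules (the name linear_beta_snr0 is only a suggestion, the beta_start/beta_end defaults mirror the usual linear-beta values, and enforce_zero_terminal_snr refers to the rescaling sketch earlier in the thread):

@NoiseSchedules.add_def("linear_beta_snr0", "Linear beta schedule rescaled to zero terminal SNR")
def _linear_beta_snr0(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
    return enforce_zero_terminal_snr(betas)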

@marksgraham
Collaborator

As it stands, I think it makes sense to allow users to specify how they want to do the timestep spacing. I've found it helps in some of my applications. I tried to get some results for FID scores on MNIST, but the results are unclear. I propose we give users the option and let them decide what is right for their application. There is a suggestion of how to do it in the linked PR; what do you guys think? If it looks OK I'll implement it in the other schedulers.

I haven't managed to get models with SNR=0 at t=T to train well at all, so I'm reluctant to implement it. I also note they weren't fully convinced of its utility in the huggingface discussion.
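On the spacing option, a rough sketch of one possible interface, with a timestep_spacing argument along the lines of the diffusers API (the parameter name and choices here are illustrative, not the exact interface in the linked PR):

import numpy as np
import torch

def set_timesteps(self, num_inference_steps: int, device=None, timestep_spacing: str = "leading") -> None:
    # Illustrative sketch: let the caller choose how inference timesteps are spaced.
    self.num_inference_steps = num_inference_steps
    if timestep_spacing == "leading":
        step_ratio = self.num_train_timesteps // num_inference_steps
        timesteps = (np.arange(0, num_inference_steps) * step_ratio)[::-1].astype(np.int64)
    elif timestep_spacing == "linspace":
        timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)
    elif timestep_spacing == "trailing":
        step_ratio = self.num_train_timesteps / num_inference_steps
        timesteps = (np.arange(self.num_train_timesteps, 0, -step_ratio).round() - 1).astype(np.int64)
    else:
        raise ValueError(f"Unknown timestep_spacing: {timestep_spacing}")
    self.timesteps = torch.from_numpy(timesteps.copy()).to(device)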

@sRassmann

Hi, I found a related issue with the DDIM scheduler. When using a schedule with a steep terminal SNR decay, like a cosine schedule*, my results are super bad:

[image: sample outputs with the Leading spacing]

However, if I change the timestep spacing from the Leading to the Linspace method (notation from Table 2 of the "Common Diffusion Noise Schedules and Sample Steps are Flawed" paper), my results are a lot better:

    def set_timesteps(
        self, num_inference_steps: int, device: str | torch.device | None = None
    ) -> None:
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        # creates integer timesteps by multiplying by ratio
        # casting to int to avoid issues when num_inference_step is power of 3
        # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)  # Leading 
        timesteps = (
            np.linspace(0, self.num_train_timesteps - 1, num_inference_steps)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )  # Linspace
        self.timesteps = torch.from_numpy(timesteps).to(device)
        self.timesteps += self.steps_offset
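For concreteness, a quick check of what the two spacings produce for 1000 training steps and 50 inference steps:

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50
step_ratio = num_train_timesteps // num_inference_steps

leading = (np.arange(0, num_inference_steps) * step_ratio)[::-1].astype(np.int64)
linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)

print(leading[:3])   # [980 960 940] -> sampling never starts from the final training timestep t=999
print(linspace[:3])  # [999 979 958] -> sampling starts from t=999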

[image: sample outputs with the Linspace spacing]

Intuitively, I guess it makes sense that skipping the earliest sampling steps, where the SNR decays steeply, is an issue for a cosine schedule, while that region likely does not matter so much for something like a linear schedule.

Anyway, how about letting the user decide which scheme to use, similar to how diffusers handles it?

*Note: I am using a custom implementation for the cosine schedule function

import numpy as np
import torch

def betas_for_alpha_bar(
    num_diffusion_timesteps, alpha_bar, max_beta=0.999
):  # https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L45C1-L62C27
    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)


@NoiseSchedules.add_def("cosine_poly", "Cosine schedule")
def _cosine_beta(num_train_timesteps: int, s: float = 8e-3, order: float = 2, *args):
    return betas_for_alpha_bar(
        num_train_timesteps,
        lambda t: np.cos((t + s) / (1 + s) * np.pi / 2) ** order,
    )
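Assuming the constructor sketched earlier in the thread, where extra keyword arguments are forwarded to the registered schedule function, this would be used roughly like (illustrative only):

scheduler = DDIMScheduler(num_train_timesteps=1000, schedule="cosine_poly", order=2)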

@marksgraham
Collaborator

Hi, sorry for the slow response.

We've had another report of bad results with the cosine scheduler here too.

I actually started implementing this but closed the PR because I couldn't find any benefit; it does seem worth implementing, though. My closed PR is here.

Do you have any interest in resurrecting it and doing a new PR? If not, I will do it, but it won't be a priority for a while, as we will be focusing on integration with MONAI core for the moment.

@sRassmann

No worries, and thanks for your comment!

I want to look into it, though I have some deadlines coming up, so I fear I won't have time in the next couple of weeks.

@virginiafdez
Contributor

Can this issue be closed, since there is a new issue #489 open?

@marksgraham
Collaborator

I'd say so :)
