Commit 4305bb8

🐛 fix cosine noise scheduler (#8427)
Fixes: Update cosine noise scheduling

### Description

In the current DDPMScheduler implementation, using the `cosine` noise schedule results in a division-by-zero issue during sampling. Specifically, `scheduler.alphas_cumprod[0] == 1.0`, which causes NaN values in the output image.

You can reproduce the issue with the following snippet:

```python
from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler
import torch

N = 250
device = 'cuda'

model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=[64, 64, 128],
    attention_levels=[False, False, True],
    num_head_channels=[0, 0, 128],
    num_res_blocks=2,
).to(device)

scheduler = DDPMScheduler(num_train_timesteps=N, schedule="cosine").to(device)
# Prints 1.0 before this fix, which is the source of the division by zero
print(scheduler.alphas_cumprod[0])

inferer = DiffusionInferer(scheduler)
scheduler.set_timesteps(num_inference_steps=N)

noise = torch.randn((1, 1, 32, 40, 32))
noise = noise.to(device)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)

# Every voxel must be finite; .all() also catches partially corrupted outputs
assert torch.isfinite(image).all(), "Image has NaN values"
```

### Fix

This PR clips the beta values first and then recomputes `alphas_cumprod` from the clipped values before returning in `monai/networks/schedulers/scheduler.py: 112`. This ensures numerical stability during sampling and prevents NaNs in the output.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Slava Shen <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
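To make the division by zero concrete, here is a minimal sketch in plain Python. It assumes the sampler uses the standard DDPM posterior mean (Ho et al., 2020), whose coefficients divide by `1 - alpha_bar_t`; the sampling code itself is not part of this commit, so that is an assumption, and `posterior_mean_coeffs` is a hypothetical helper, not a MONAI function. Plain Python floats raise `ZeroDivisionError` here, whereas torch tensors would silently produce `inf` or `NaN`.

```python
import math


def posterior_mean_coeffs(alpha_bar_t, alpha_bar_prev, beta_t):
    """Hypothetical helper: coefficients of x_0 and x_t in the DDPM posterior mean.

    mean = coeff_x0 * x_0 + coeff_xt * x_t, with both coefficients
    sharing the denominator (1 - alpha_bar_t).
    """
    alpha_t = 1.0 - beta_t
    denom = 1.0 - alpha_bar_t  # exactly 0.0 when alpha_bar_t == 1.0
    coeff_x0 = math.sqrt(alpha_bar_prev) * beta_t / denom
    coeff_xt = math.sqrt(alpha_t) * (1.0 - alpha_bar_prev) / denom
    return coeff_x0, coeff_xt
```

With `alpha_bar_t = 1.0` (the value the description reports for `alphas_cumprod[0]`), the call fails on the shared denominator. With a value slightly below 1, such as `0.9998`, both coefficients are finite.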
1 parent 0d19a72 commit 4305bb8

1 file changed, +5 -3 lines changed

monai/networks/schedulers/scheduler.py

@@ -105,9 +105,11 @@ def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
     x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
     alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
     alphas_cumprod /= alphas_cumprod[0].item()
-    alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
-    betas = 1.0 - alphas
-    return betas, alphas, alphas_cumprod[:-1]
+    betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    betas = torch.clip(betas, 0.0, 0.999)
+    alphas = 1.0 - betas
+    alphas_cumprod = torch.cumprod(alphas, dim=0)
+    return betas, alphas, alphas_cumprod
 
 
 class Scheduler(nn.Module):
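
The effect of the change on the first entry of `alphas_cumprod` can be checked without a GPU or a model. The following is a dependency-free sketch that mirrors the removed and added lines using plain Python lists instead of torch tensors; the function names `old_cosine_beta` and `new_cosine_beta` are made up for illustration and do not exist in MONAI.

```python
import math


def cosine_curve(num_train_timesteps, s=8e-3):
    """Cosine curve at num_train_timesteps + 1 points, normalised so entry 0 is 1."""
    raw = [
        math.cos(((i / num_train_timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
        for i in range(num_train_timesteps + 1)
    ]
    return [v / raw[0] for v in raw]


def old_cosine_beta(num_train_timesteps, s=8e-3):
    """Mirrors the removed lines: clip alphas, return the unclipped, shifted curve."""
    curve = cosine_curve(num_train_timesteps, s)
    alphas = [min(max(b / a, 0.0001), 0.9999) for a, b in zip(curve[:-1], curve[1:])]
    betas = [1.0 - a for a in alphas]
    return betas, alphas, curve[:-1]  # first entry is exactly 1.0


def new_cosine_beta(num_train_timesteps, s=8e-3):
    """Mirrors the added lines: clip betas, then rebuild the cumulative product."""
    curve = cosine_curve(num_train_timesteps, s)
    betas = [min(max(1.0 - b / a, 0.0), 0.999) for a, b in zip(curve[:-1], curve[1:])]
    alphas = [1.0 - b for b in betas]
    alphas_cumprod, running = [], 1.0
    for a in alphas:
        running *= a
        alphas_cumprod.append(running)
    return betas, alphas, alphas_cumprod  # first entry equals alphas[0] < 1
```

For `num_train_timesteps=250`, the old variant returns `alphas_cumprod[0] == 1.0`, while the new variant returns a first entry strictly below 1, so `1 - alphas_cumprod[0]` is no longer zero. The rebuilt product is also consistent with the clipped `alphas`, which the old shifted slice was not.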
