You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Fixes: Update cosine noise scheduling
### Description
In the current DDPMScheduler implementation, using the `cosine` noise
schedule results in a division-by-zero issue during sampling.
Specifically, `scheduler.alphas_cumprod[0] == 1.0`, which causes NaN
values in the output image.
You can reproduce the issue with the following snippet:
```python
from monai.inferers import DiffusionInferer
from monai.networks.nets import DiffusionModelUNet
from monai.networks.schedulers import DDPMScheduler
import torch
N = 250
device = 'cuda'
model = DiffusionModelUNet(
spatial_dims=3,
in_channels=1,
out_channels=1,
channels=[64, 64, 128],
attention_levels=[False, False, True],
num_head_channels=[0, 0, 128],
num_res_blocks=2,
).to(device)
scheduler = DDPMScheduler(num_train_timesteps=N, schedule="cosine").to(device)
print(scheduler.alphas_cumprod[0])
inferer = DiffusionInferer(scheduler)
scheduler.set_timesteps(num_inference_steps=N)
noise = torch.randn((1, 1, 32, 40, 32))
noise = noise.to(device)
image = inferer.sample(input_noise=noise, diffusion_model=model, scheduler=scheduler)
assert torch.isfinite(image).any(), "Image has NaN values"
```
### Fix
This PR applies clipping to the beta values first and re-computes
alphas_cumprod accordingly before returning in
`monai/networks/schedulers/scheduler.py: 112`. This ensures numerical
stability during sampling and prevents NaNs in the output.
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
Signed-off-by: Slava Shen <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
0 commit comments