Description
Bug description
Hi all, and thanks for the awesome library 🚀
Since #21055 was merged (part of the 2.5.3 release), the learning rate finders no longer find the optimal learning rate when the losses are noisy.
The steps to reproduce are below: run the script once with pytorch-lightning<2.5.3 and once with >=2.5.3. The example shows the LR finder's behavior on synthetic noise-free and noisy losses.
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.callbacks.lr_finder import _LRFinder

# set manually: True when running with pytorch-lightning<2.5.3, False for >=2.5.3 (only changes the plot title)
use_lightning_le_2_5_3 = True

np.random.seed(42)
n_lrs = 100

# generate losses (noise-free and noisy)
x = np.linspace(-50, 50, n_lrs)
losses = (-x**2 - 30*x + 4000) / 1000.
noise = np.random.uniform(-0.1, 0.1, n_lrs)
losses_noisy = losses + noise

# generate exponential lrs (as done in _LRFinder)
lr_min = 1e-8
lr_max = 1.
lrs = []
for i in range(n_lrs):
    r = (i + 1) / n_lrs
    lrs.append(lr_min * (lr_max / lr_min) ** r)

fig, axes = plt.subplots(ncols=2, figsize=(10, 4))
for ax, losses_, name in zip(
    axes,
    [losses, losses_noisy],
    ["noise-free", "noisy"],
):
    # fill in the results manually to simulate a lr_finder run without actual training
    lr_finder = _LRFinder(mode="exponential", lr_min=lr_min, lr_max=lr_max, num_training=n_lrs)
    lr_finder.results = {
        "lr": lrs,
        "loss": losses_
    }
    # plot with the suggested lr
    lr_finder.plot(suggest=True, ax=ax)
    ax.set_title(f"LR Finder ({name})")

plt.suptitle(f"pytorch-lightning{'<2.5.3' if use_lightning_le_2_5_3 else '>=2.5.3'}")
fig.tight_layout()
plt.show(block=True)
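This produces the following plots, with the suggested learning rate marked:

[Figures: "LR Finder (noise-free)" and "LR Finder (noisy)" panels, one figure for pytorch-lightning<2.5.3 and one for pytorch-lightning>=2.5.3]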
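As you can see, before 2.5.3 the suggestion is unaffected by the noise in the losses. From 2.5.3 on, the noisy losses lead to a suboptimal suggestion at the low end of the learning rate range, because the gradients computed with the newly introduced spacing have a much larger magnitude there (the exponentially spaced learning rates are packed very closely at the low end, so small loss differences are divided by tiny learning rate differences).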
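To make the spacing effect concrete, here is a minimal sketch that reuses the synthetic data from the script and compares the steepest-descent point under unit (index) spacing and under learning rate spacing. It assumes that #21055 switched from an index-based gradient to a gradient with respect to the learning rates, and that the suggestion picks the most negative gradient after skipping the first 10 and last 1 points; `steepest_index` is a hypothetical helper, not Lightning's actual code.

```python
import numpy as np

# Synthetic losses, identical to the reproduction script in this issue.
np.random.seed(42)
n_lrs = 100
x = np.linspace(-50, 50, n_lrs)
losses = (-x**2 - 30 * x + 4000) / 1000.0
losses_noisy = losses + np.random.uniform(-0.1, 0.1, n_lrs)

# Exponentially spaced learning rates, as in the script.
lr_min, lr_max = 1e-8, 1.0
r = np.arange(1, n_lrs + 1) / n_lrs
lrs = lr_min * (lr_max / lr_min) ** r


def steepest_index(loss, lrs=None, skip_begin=10, skip_end=1):
    """Hypothetical helper: index of the most negative loss gradient.

    lrs=None uses unit (index) spacing; otherwise the gradient is taken
    with respect to the learning rates. The skip defaults are an
    assumption about the LR finder's suggestion logic.
    """
    loss = np.asarray(loss, dtype=float)[skip_begin:-skip_end]
    if lrs is None:
        grad = np.gradient(loss)
    else:
        grad = np.gradient(loss, np.asarray(lrs, dtype=float)[skip_begin:-skip_end])
    return skip_begin + int(np.argmin(grad))


# Consecutive lr gaps grow geometrically: the first gap is ~7e7 times smaller than the last.
gaps = np.diff(lrs)
print(f"first gap: {gaps[0]:.2e}, last gap: {gaps[-1]:.2e}")

for name, loss in [("noise-free", losses), ("noisy", losses_noisy)]:
    i_unit = steepest_index(loss)
    i_lr = steepest_index(loss, lrs)
    print(f"{name:>10}: unit spacing -> lr={lrs[i_unit]:.1e}, lr spacing -> lr={lrs[i_lr]:.1e}")
```

With exponentially spaced learning rates, unit spacing is proportional to spacing in log(lr), so an index-based gradient weighs both ends of the range alike, while an lr-based gradient divides the same loss noise by gaps that differ by many orders of magnitude.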
Noisy losses occur naturally in our real use cases. I work on the Darts forecasting library, whose neural network framework is built around Lightning.
Expected behavior
The LR finder's suggestion should ideally be robust to minor noise in the losses (large noise is, of course, a different story).
Proposed solutions
That's a good question. Maybe the LR finder could denoise the losses (e.g., with filtering or a moving average) before computing the gradients?
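One way to try the moving-average idea is sketched below, as an assumption rather than an existing Lightning option: a hypothetical `smooth_losses` helper applies a centered moving average (edges padded with the boundary values so the output stays aligned with the learning rates) before the gradient is taken with respect to the learning rates. The window sizes are arbitrary choices for illustration.

```python
import numpy as np


def smooth_losses(loss, window=5):
    """Hypothetical helper: centered moving average that keeps the array length.

    The ends are padded by repeating the first/last value so the output
    aligns index-by-index with the learning rates.
    """
    if window < 1 or window % 2 == 0:
        raise ValueError("window must be a positive odd integer")
    loss = np.asarray(loss, dtype=float)
    half = window // 2
    padded = np.pad(loss, (half, half), mode="edge")
    kernel = np.ones(window) / window
    return np.convolve(padded, kernel, mode="valid")


def steepest_index(loss, lrs, skip_begin=10, skip_end=1):
    """Hypothetical helper: index of the most negative d(loss)/d(lr)."""
    loss = np.asarray(loss, dtype=float)[skip_begin:-skip_end]
    grad = np.gradient(loss, np.asarray(lrs, dtype=float)[skip_begin:-skip_end])
    return skip_begin + int(np.argmin(grad))


# Same synthetic noisy losses and exponential lrs as in the reproduction script.
np.random.seed(42)
n_lrs = 100
x = np.linspace(-50, 50, n_lrs)
losses_noisy = (-x**2 - 30 * x + 4000) / 1000.0 + np.random.uniform(-0.1, 0.1, n_lrs)
lrs = 1e-8 * (1.0 / 1e-8) ** (np.arange(1, n_lrs + 1) / n_lrs)

# window=1 leaves the losses unchanged, for comparison.
for window in (1, 5, 11):
    idx = steepest_index(smooth_losses(losses_noisy, window), lrs)
    print(f"window={window:>2}: suggested lr={lrs[idx]:.1e}")
```

Averaging shrinks the point-to-point differences that the gradient divides by the learning rate gaps, but it does not remove the amplification at small learning rates, so whether a given window is enough would have to be checked on real loss curves.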
Thanks again, and if you need more info, let me know!
What version are you seeing the problem on?
v2.5
Environment
Current environment
* CUDA:
    - GPU: None
    - available: False
    - version: None
* Lightning:
    - lightning: 2.5.3
    - lightning-utilities: 0.15.2
    - pytorch-lightning: 2.5.2
    - torch: 2.8.0
    - torchmetrics: 1.8.1
* System:
    - OS: Darwin
    - architecture:
        - 64bit
        -
    - processor: arm
    - python: 3.12.11
    - release: 24.3.0
    - version: Darwin Kernel Version 24.3.0: Thu Jan 2 20:24:16 PST 2025; root:xnu-
* installed with `pip`
cc @lantiga