Skip to content

Commit a379b3b

Browse files
committed
BUG: Raise early on non-finite values in PSD (Welch) and ICA.fit (Fixes #13364)
1 parent 7cfcc6b commit a379b3b

File tree

5 files changed

+74
-1
lines changed

5 files changed

+74
-1
lines changed

doc/changes/devel/13364.bugfix.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
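Raise a clear :class:`ValueError` when the input to :func:`mne.time_frequency.psd_array_welch` contains Inf in the analyzed samples or NaN spans that are not aligned across channels,
2+
and when the input to :meth:`mne.preprocessing.ICA.fit` contains any non-finite value (NaN/Inf), instead of failing on a deep assertion (GH-13364).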

mne/preprocessing/ica.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,13 @@ def _pre_whiten(self, data):
891891

892892
def _fit(self, data, fit_type):
893893
"""Aux function."""
894+
if not np.isfinite(data).all():
895+
raise ValueError(
896+
"Input data contains non-finite values (NaN/Inf). "
897+
"Please clean your data (e.g., high-pass filter, interpolate or drop "
898+
"contaminated segments) before calling ICA.fit()."
899+
)
900+
894901
random_state = check_random_state(self.random_state)
895902
n_channels, n_samples = data.shape
896903
self._compute_pre_whitener(data)

mne/preprocessing/tests/test_ica.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,3 +1718,30 @@ def test_ica_ch_types(ch_type):
17181718
for inst in [raw, epochs, evoked]:
17191719
ica.apply(inst)
17201720
ica.get_sources(inst)
1721+
1722+
1723+
@pytest.mark.filterwarnings(
    "ignore:The data has not been high-pass filtered.:RuntimeWarning"
)
@pytest.mark.filterwarnings(
    "ignore:invalid value encountered in subtract:RuntimeWarning"
)
def test_ica_rejects_nonfinite():
    """Check that ICA.fit raises a ValueError for NaN or Inf samples."""
    ch_names = ["Fz", "Cz", "Pz", "Oz"]
    info = create_info(ch_names, sfreq=100.0, ch_types="eeg")
    clean = np.random.RandomState(1).randn(4, 1000)
    pattern = r"Input data contains non[- ]?finite values"

    # One contaminated sample per case: (channel index, sample index, bad value)
    cases = ((0, 25, np.nan), (1, 50, np.inf))
    for ch_idx, samp_idx, bad_value in cases:
        # Start from a fresh copy so each case holds exactly one bad sample
        raw = RawArray(clean.copy(), info)
        raw._data[ch_idx, samp_idx] = bad_value
        ica = ICA(n_components=2, random_state=0, method="fastica", max_iter="auto")
        with pytest.raises(ValueError, match=pattern):
            ica.fit(raw)

mne/time_frequency/psd.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,16 @@ def psd_array_welch(
200200
del freq_mask
201201
freqs = freqs[freq_sl]
202202

203+
# Hard error on Inf inside analyzed samples
204+
step = max(n_per_seg - n_overlap, 1)
205+
n_segments = 1 + (n_times - n_per_seg) // step if n_times >= n_per_seg else 0
206+
analyzed_end = step * (n_segments - 1) + n_per_seg if n_segments > 0 else 0
207+
if analyzed_end > 0 and np.isinf(x[..., :analyzed_end]).any():
208+
raise ValueError(
209+
"Input data contains non-finite values (Inf) in the analyzed time span. "
210+
"Clean or drop bad segments before computing the PSD."
211+
)
212+
203213
# Parallelize across first N-1 dimensions
204214
logger.debug(
205215
f"Spectogram using {n_fft}-point FFT on {n_per_seg} samples with "
@@ -221,7 +231,12 @@ def psd_array_welch(
221231
good_mask = ~np.isnan(x)
222232
# NaNs originate from annot, so must match for all channels. Note that we CANNOT
223233
# use np.testing.assert_allclose() here; it is strict about shapes/broadcasting
224-
assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
234+
if not np.allclose(good_mask, good_mask[[0]], equal_nan=True):
235+
raise ValueError(
236+
"Input data contains NaN masks that are not aligned across channels; "
237+
"make NaN spans consistent across channels or clean/drop bad segments."
238+
)
239+
# assert np.allclose(good_mask, good_mask[[0]], equal_nan=True)
225240
t_onsets, t_offsets = _mask_to_onsets_offsets(good_mask[0])
226241
x_splits = [x[..., t_ons:t_off] for t_ons, t_off in zip(t_onsets, t_offsets)]
227242
# weights reflect the number of samples used from each span. For spans longer

mne/time_frequency/tests/test_psd.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,25 @@ def test_psd_array_welch_n_jobs():
226226
data = np.zeros((1, 2048))
227227
psd_array_welch(data, 1024, n_jobs=1)
228228
psd_array_welch(data, 1024, n_jobs=2)
229+
230+
231+
def test_psd_raises_on_inf_in_analyzed_window_array():
    """Check that psd_array_welch rejects +Inf and -Inf in analyzed samples."""
    n_samples, n_fft, n_overlap = 2048, 256, 128
    rng = np.random.RandomState(0)
    clean = rng.randn(1, n_samples)
    # The guard uses np.isinf, so both signs of infinity must be rejected
    for bad_value in (np.inf, -np.inf):
        x = clean.copy()
        # Sample 800 lies inside the span covered by the Welch segments
        x[0, 800] = bad_value
        # Match the specific message so an unrelated ValueError cannot pass
        with pytest.raises(ValueError, match=r"non-finite values \(Inf\)"):
            psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)
240+
241+
242+
def test_psd_raises_on_misaligned_nan_across_channels():
    """Check that NaN masks differing across channels raise a ValueError."""
    n_samples, n_fft, n_overlap = 2048, 256, 128
    rng = np.random.RandomState(42)
    x = rng.randn(2, n_samples)
    # NaN only in channel 0; channel 1 is clean, so the per-channel masks differ
    x[0, 500] = np.nan
    # Match the specific message so an unrelated NaN/Inf error cannot pass
    with pytest.raises(ValueError, match="not aligned across channels"):
        psd_array_welch(x, float(n_fft), n_fft=n_fft, n_overlap=n_overlap)

0 commit comments

Comments
 (0)