Skip to content

Commit f7b60cf

Browse files
committed
Add test for ARIMA exogenous variable support
1 parent a67dd7d commit f7b60cf

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

aeon/forecasting/stats/tests/test_arima.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,60 @@ def test_autoarima_forecast_is_consistent_with_wrapped():
190190
forecaster = AutoARIMA()
191191
val = forecaster._forecast(y)
192192
assert np.isclose(val, forecaster.final_model_.forecast_)
193+
194+
195+
def test_arima_with_exog_basic_fit_predict():
196+
"""Test ARIMA fit and predict with exogenous variables."""
197+
y_local = np.arange(50, dtype=float)
198+
exog = np.random.RandomState(42).randn(50, 2)
199+
200+
model = ARIMA(p=1, d=0, q=1)
201+
model.fit(y_local, exog=exog)
202+
pred = model.predict(y_local, exog=exog[-1:].copy())
203+
204+
assert isinstance(pred, float)
205+
assert np.isfinite(pred)
206+
207+
208+
def test_arima_exog_shape_mismatch_raises():
209+
"""Test that exogenous shape mismatches raise ValueError."""
210+
y_local = np.arange(20, dtype=float)
211+
exog = np.random.RandomState(0).randn(20, 3)
212+
213+
model = ARIMA(p=1, d=0, q=1)
214+
215+
with pytest.raises(ValueError):
216+
model.fit(y_local, exog=np.random.randn(10, 3))
217+
218+
model.fit(y_local, exog=exog)
219+
220+
with pytest.raises(ValueError):
221+
model.predict(y_local, exog=np.random.randn(1, 5))
222+
223+
224+
def test_arima_iterative_forecast_with_exog():
225+
"""Test multi-step forecast with future exogenous variables."""
226+
y_local = np.arange(40, dtype=float)
227+
exog = np.random.RandomState(1).randn(40, 2)
228+
229+
model = ARIMA(p=1, d=1, q=1)
230+
model.fit(y_local, exog=exog)
231+
232+
h = 5
233+
future_exog = np.random.RandomState(2).randn(h, 2)
234+
preds = model.iterative_forecast(y_local, prediction_horizon=h, exog=future_exog)
235+
236+
assert preds.shape == (h,)
237+
assert np.all(np.isfinite(preds))
238+
239+
240+
def test_arima_no_exog_backward_compatibility():
241+
"""Test ARIMA works normally when no exogenous variables are provided."""
242+
y_local = np.arange(30, dtype=float)
243+
244+
model = ARIMA(p=1, d=1, q=1)
245+
model.fit(y_local)
246+
pred = model.predict(y_local)
247+
248+
assert isinstance(pred, float)
249+
assert np.isfinite(pred)

0 commit comments

Comments
 (0)