Skip to content

Commit c099fc4

Browse files
Forecast exogenous vars bug fix (#510)
* fixed bug in statespace forecast method when exogenous variables are present. * updated solution to handle input shapes correctly * simplified fix, renamed mu and cov for transparancy and added a check for the graph replacements * Refactor model builder logic out of `forecast` method * made slight change with _build_forecast_model and created a test case * made change to test_build_forecast_model() to ensure data is replaced with pm.set_data method * added additional checks to test_build_forecast_model * added mock_sample_setup_and_teardown to statespace tests --------- Co-authored-by: jessegrabowski <[email protected]>
1 parent 120411f commit c099fc4

File tree

2 files changed

+174
-57
lines changed

2 files changed

+174
-57
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 78 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,69 @@ def _finalize_scenario_initialization(
20472047

20482048
return scenario
20492049

2050+
def _build_forecast_model(
2051+
self, time_index, t0, forecast_index, scenario, filter_output, mvn_method
2052+
):
2053+
filter_time_dim = TIME_DIM
2054+
temp_coords = self._fit_coords.copy()
2055+
2056+
dims = None
2057+
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2058+
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2059+
2060+
t0_idx = np.flatnonzero(time_index == t0)[0]
2061+
2062+
temp_coords["data_time"] = time_index
2063+
temp_coords[TIME_DIM] = forecast_index
2064+
2065+
mu_dims, cov_dims = None, None
2066+
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2067+
mu_dims = ["data_time", ALL_STATE_DIM]
2068+
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2069+
2070+
with pm.Model(coords=temp_coords) as forecast_model:
2071+
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2072+
data_dims=["data_time", OBS_STATE_DIM],
2073+
)
2074+
2075+
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2076+
mu, cov = grouped_outputs[group_idx]
2077+
2078+
sub_dict = {
2079+
data_var: pt.as_tensor_variable(data_var.get_value(), name="data")
2080+
for data_var in forecast_model.data_vars
2081+
}
2082+
2083+
missing_data_vars = np.setdiff1d(
2084+
ar1=[*self.data_names, "data"], ar2=[k.name for k, _ in sub_dict.items()]
2085+
)
2086+
if missing_data_vars.size > 0:
2087+
raise ValueError(f"{missing_data_vars} data used for fitting not found!")
2088+
2089+
mu_frozen, cov_frozen = graph_replace([mu, cov], replace=sub_dict, strict=True)
2090+
2091+
x0 = pm.Deterministic(
2092+
"x0_slice", mu_frozen[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2093+
)
2094+
P0 = pm.Deterministic(
2095+
"P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2096+
)
2097+
2098+
_ = LinearGaussianStateSpace(
2099+
"forecast",
2100+
x0,
2101+
P0,
2102+
*matrices,
2103+
steps=len(forecast_index),
2104+
dims=dims,
2105+
sequence_names=self.kalman_filter.seq_names,
2106+
k_endog=self.k_endog,
2107+
append_x0=False,
2108+
method=mvn_method,
2109+
)
2110+
2111+
return forecast_model
2112+
20502113
def forecast(
20512114
self,
20522115
idata: InferenceData,
@@ -2139,8 +2202,6 @@ def forecast(
21392202
the latent state trajectories: `y[t] = Z @ x[t] + nu[t]`, where `nu ~ N(0, H)`.
21402203
21412204
"""
2142-
filter_time_dim = TIME_DIM
2143-
21442205
_validate_filter_arg(filter_output)
21452206

21462207
compile_kwargs = kwargs.pop("compile_kwargs", {})
@@ -2185,58 +2246,23 @@ def forecast(
21852246
use_scenario_index=use_scenario_index,
21862247
)
21872248
scenario = self._finalize_scenario_initialization(scenario, forecast_index)
2188-
temp_coords = self._fit_coords.copy()
2189-
2190-
dims = None
2191-
if all([dim in temp_coords for dim in [filter_time_dim, ALL_STATE_DIM, OBS_STATE_DIM]]):
2192-
dims = [TIME_DIM, ALL_STATE_DIM, OBS_STATE_DIM]
2193-
2194-
t0_idx = np.flatnonzero(time_index == t0)[0]
2195-
2196-
temp_coords["data_time"] = time_index
2197-
temp_coords[TIME_DIM] = forecast_index
2198-
2199-
mu_dims, cov_dims = None, None
2200-
if all([dim in self._fit_coords for dim in [TIME_DIM, ALL_STATE_DIM, ALL_STATE_AUX_DIM]]):
2201-
mu_dims = ["data_time", ALL_STATE_DIM]
2202-
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
2203-
2204-
with pm.Model(coords=temp_coords) as forecast_model:
2205-
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2206-
scenario=scenario,
2207-
data_dims=["data_time", OBS_STATE_DIM],
2208-
)
2209-
2210-
for name in self.data_names:
2211-
if name in scenario.keys():
2212-
pm.set_data(
2213-
{"data": np.zeros((len(forecast_index), self.k_endog))},
2214-
coords={"data_time": np.arange(len(forecast_index))},
2215-
)
2216-
break
22172249

2218-
group_idx = FILTER_OUTPUT_TYPES.index(filter_output)
2219-
mu, cov = grouped_outputs[group_idx]
2220-
2221-
x0 = pm.Deterministic(
2222-
"x0_slice", mu[t0_idx], dims=mu_dims[1:] if mu_dims is not None else None
2223-
)
2224-
P0 = pm.Deterministic(
2225-
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
2226-
)
2250+
forecast_model = self._build_forecast_model(
2251+
time_index=time_index,
2252+
t0=t0,
2253+
forecast_index=forecast_index,
2254+
scenario=scenario,
2255+
filter_output=filter_output,
2256+
mvn_method=mvn_method,
2257+
)
22272258

2228-
_ = LinearGaussianStateSpace(
2229-
"forecast",
2230-
x0,
2231-
P0,
2232-
*matrices,
2233-
steps=len(forecast_index),
2234-
dims=dims,
2235-
sequence_names=self.kalman_filter.seq_names,
2236-
k_endog=self.k_endog,
2237-
append_x0=False,
2238-
method=mvn_method,
2239-
)
2259+
with forecast_model:
2260+
if scenario is not None:
2261+
dummy_obs_data = np.zeros((len(forecast_index), self.k_endog))
2262+
pm.set_data(
2263+
scenario | {"data": dummy_obs_data},
2264+
coords={"data_time": np.arange(len(forecast_index))},
2265+
)
22402266

22412267
forecast_model.rvs_to_initial_values = {
22422268
k: None for k in forecast_model.rvs_to_initial_values.keys()

tests/statespace/core/test_statespace.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
import pytest
1010

1111
from numpy.testing import assert_allclose
12+
from pymc.testing import mock_sample_setup_and_teardown
13+
from pytensor.compile import SharedVariable
14+
from pytensor.graph.basic import graph_inputs
1215

1316
from pymc_extras.statespace.core.statespace import FILTER_FACTORY, PyMCStateSpace
1417
from pymc_extras.statespace.models import structural as st
@@ -30,6 +33,7 @@
3033
floatX = pytensor.config.floatX
3134
nile = load_nile_test_data()
3235
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
36+
mock_pymc_sample = pytest.fixture(scope="session")(mock_sample_setup_and_teardown)
3337

3438

3539
def make_statespace_mod(k_endog, k_states, k_posdef, filter_type, verbose=False, data_info=None):
@@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
170174
)
171175
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
172176

173-
exog_ss_mod.build_statespace_graph(exog_data["y"])
177+
exog_ss_mod.build_statespace_graph(exog_data["y"], save_kalman_filter_outputs_in_idata=True)
174178

175179
return struct_model
176180

@@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
212216

213217

214218
@pytest.fixture(scope="session")
215-
def idata(pymc_mod, rng):
219+
def idata(pymc_mod, rng, mock_pymc_sample):
216220
with pymc_mod:
217221
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
218222
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -222,7 +226,7 @@ def idata(pymc_mod, rng):
222226

223227

224228
@pytest.fixture(scope="session")
225-
def idata_exog(exog_pymc_mod, rng):
229+
def idata_exog(exog_pymc_mod, rng, mock_pymc_sample):
226230
with exog_pymc_mod:
227231
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
228232
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng):
231235

232236

233237
@pytest.fixture(scope="session")
234-
def idata_no_exog(pymc_mod_no_exog, rng):
238+
def idata_no_exog(pymc_mod_no_exog, rng, mock_pymc_sample):
235239
with pymc_mod_no_exog:
236240
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
237241
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng):
240244

241245

242246
@pytest.fixture(scope="session")
243-
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng):
247+
def idata_no_exog_dt(pymc_mod_no_exog_dt, rng, mock_pymc_sample):
244248
with pymc_mod_no_exog_dt:
245249
idata = pm.sample(draws=10, tune=0, chains=1, random_seed=rng)
246250
idata_prior = pm.sample_prior_predictive(draws=10, random_seed=rng)
@@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
895899
assert_allclose(regression_effect, regression_effect_expected)
896900

897901

902+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
903+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
904+
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
905+
@pytest.mark.filterwarnings("ignore:Skipping `CheckAndRaise` Op")
906+
@pytest.mark.filterwarnings("ignore:No frequency was specific on the data's DateTimeIndex.")
907+
def test_build_forecast_model(rng, exog_ss_mod, exog_pymc_mod, exog_data, idata_exog):
908+
data_before_build_forecast_model = {d.name: d.get_value() for d in exog_pymc_mod.data_vars}
909+
910+
scenario = pd.DataFrame(
911+
{
912+
"date": pd.date_range(start="2023-05-11", end="2023-05-20", freq="D"),
913+
"x1": rng.choice(2, size=10, replace=True).astype(float),
914+
}
915+
)
916+
scenario.set_index("date", inplace=True)
917+
918+
time_index = exog_ss_mod._get_fit_time_index()
919+
t0, forecast_index = exog_ss_mod._build_forecast_index(
920+
time_index=time_index,
921+
start=exog_data.index[-1],
922+
end=scenario.index[-1],
923+
scenario=scenario,
924+
)
925+
926+
test_forecast_model = exog_ss_mod._build_forecast_model(
927+
time_index=time_index,
928+
t0=t0,
929+
forecast_index=forecast_index,
930+
scenario=scenario,
931+
filter_output="predicted",
932+
mvn_method="svd",
933+
)
934+
935+
frozen_shared_inputs = [
936+
inpt
937+
for inpt in graph_inputs([test_forecast_model.x0_slice, test_forecast_model.P0_slice])
938+
if isinstance(inpt, SharedVariable)
939+
and not isinstance(inpt.get_value(), np.random.Generator)
940+
]
941+
942+
assert (
943+
len(frozen_shared_inputs) == 0
944+
) # check there are no non-random generator SharedVariables in the frozen inputs
945+
946+
unfrozen_shared_inputs = [
947+
inpt
948+
for inpt in graph_inputs([test_forecast_model.forecast_combined])
949+
if isinstance(inpt, SharedVariable)
950+
and not isinstance(inpt.get_value(), np.random.Generator)
951+
]
952+
953+
# Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data
954+
assert len(unfrozen_shared_inputs) == 1
955+
assert unfrozen_shared_inputs[0].name == "data_exog"
956+
957+
data_after_build_forecast_model = {d.name: d.get_value() for d in test_forecast_model.data_vars}
958+
959+
with test_forecast_model:
960+
dummy_obs_data = np.zeros((len(forecast_index), exog_ss_mod.k_endog))
961+
pm.set_data(
962+
{"data_exog": scenario} | {"data": dummy_obs_data},
963+
coords={"data_time": np.arange(len(forecast_index))},
964+
)
965+
idata_forecast = pm.sample_posterior_predictive(
966+
idata_exog, var_names=["x0_slice", "P0_slice"]
967+
)
968+
969+
np.testing.assert_allclose(
970+
unfrozen_shared_inputs[0].get_value(), scenario["x1"].values.reshape((-1, 1))
971+
) # ensure the replaced data matches the exogenous data
972+
973+
for k in data_before_build_forecast_model.keys():
974+
assert ( # check that the data needed to init the forecasts doesn't change
975+
data_before_build_forecast_model[k].mean() == data_after_build_forecast_model[k].mean()
976+
)
977+
978+
# Check that the frozen states and covariances correctly match the sliced index
979+
np.testing.assert_allclose(
980+
idata_exog.posterior["predicted_covariance"].sel(time=t0).mean(("chain", "draw")).values,
981+
idata_forecast.posterior_predictive["P0_slice"].mean(("chain", "draw")).values,
982+
)
983+
np.testing.assert_allclose(
984+
idata_exog.posterior["predicted_state"].sel(time=t0).mean(("chain", "draw")).values,
985+
idata_forecast.posterior_predictive["x0_slice"].mean(("chain", "draw")).values,
986+
)
987+
988+
898989
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
899990
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
900991
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")

0 commit comments

Comments
 (0)