9
9
import pytest
10
10
11
11
from numpy .testing import assert_allclose
12
+ from pymc .testing import mock_sample_setup_and_teardown
13
+ from pytensor .compile import SharedVariable
14
+ from pytensor .graph .basic import graph_inputs
12
15
13
16
from pymc_extras .statespace .core .statespace import FILTER_FACTORY , PyMCStateSpace
14
17
from pymc_extras .statespace .models import structural as st
30
33
floatX = pytensor .config .floatX
31
34
nile = load_nile_test_data ()
32
35
ALL_SAMPLE_OUTPUTS = MATRIX_NAMES + FILTER_OUTPUT_NAMES + SMOOTHER_OUTPUT_NAMES
36
+ mock_pymc_sample = pytest .fixture (scope = "session" )(mock_sample_setup_and_teardown )
33
37
34
38
35
39
def make_statespace_mod (k_endog , k_states , k_posdef , filter_type , verbose = False , data_info = None ):
@@ -170,7 +174,7 @@ def exog_pymc_mod(exog_ss_mod, exog_data):
170
174
)
171
175
beta_exog = pm .Normal ("beta_exog" , mu = 0 , sigma = 1 , dims = ["exog_state" ])
172
176
173
- exog_ss_mod .build_statespace_graph (exog_data ["y" ])
177
+ exog_ss_mod .build_statespace_graph (exog_data ["y" ], save_kalman_filter_outputs_in_idata = True )
174
178
175
179
return struct_model
176
180
@@ -212,7 +216,7 @@ def pymc_mod_no_exog_dt(ss_mod_no_exog_dt, rng):
212
216
213
217
214
218
@pytest .fixture (scope = "session" )
215
- def idata (pymc_mod , rng ):
219
+ def idata (pymc_mod , rng , mock_pymc_sample ):
216
220
with pymc_mod :
217
221
idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
218
222
idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -222,7 +226,7 @@ def idata(pymc_mod, rng):
222
226
223
227
224
228
@pytest .fixture (scope = "session" )
225
- def idata_exog (exog_pymc_mod , rng ):
229
+ def idata_exog (exog_pymc_mod , rng , mock_pymc_sample ):
226
230
with exog_pymc_mod :
227
231
idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
228
232
idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -231,7 +235,7 @@ def idata_exog(exog_pymc_mod, rng):
231
235
232
236
233
237
@pytest .fixture (scope = "session" )
234
- def idata_no_exog (pymc_mod_no_exog , rng ):
238
+ def idata_no_exog (pymc_mod_no_exog , rng , mock_pymc_sample ):
235
239
with pymc_mod_no_exog :
236
240
idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
237
241
idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -240,7 +244,7 @@ def idata_no_exog(pymc_mod_no_exog, rng):
240
244
241
245
242
246
@pytest .fixture (scope = "session" )
243
- def idata_no_exog_dt (pymc_mod_no_exog_dt , rng ):
247
+ def idata_no_exog_dt (pymc_mod_no_exog_dt , rng , mock_pymc_sample ):
244
248
with pymc_mod_no_exog_dt :
245
249
idata = pm .sample (draws = 10 , tune = 0 , chains = 1 , random_seed = rng )
246
250
idata_prior = pm .sample_prior_predictive (draws = 10 , random_seed = rng )
@@ -895,6 +899,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
895
899
assert_allclose (regression_effect , regression_effect_expected )
896
900
897
901
902
+ @pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
903
+ @pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
904
+ @pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
905
+ @pytest .mark .filterwarnings ("ignore:Skipping `CheckAndRaise` Op" )
906
+ @pytest .mark .filterwarnings ("ignore:No frequency was specific on the data's DateTimeIndex." )
907
+ def test_build_forecast_model (rng , exog_ss_mod , exog_pymc_mod , exog_data , idata_exog ):
908
+ data_before_build_forecast_model = {d .name : d .get_value () for d in exog_pymc_mod .data_vars }
909
+
910
+ scenario = pd .DataFrame (
911
+ {
912
+ "date" : pd .date_range (start = "2023-05-11" , end = "2023-05-20" , freq = "D" ),
913
+ "x1" : rng .choice (2 , size = 10 , replace = True ).astype (float ),
914
+ }
915
+ )
916
+ scenario .set_index ("date" , inplace = True )
917
+
918
+ time_index = exog_ss_mod ._get_fit_time_index ()
919
+ t0 , forecast_index = exog_ss_mod ._build_forecast_index (
920
+ time_index = time_index ,
921
+ start = exog_data .index [- 1 ],
922
+ end = scenario .index [- 1 ],
923
+ scenario = scenario ,
924
+ )
925
+
926
+ test_forecast_model = exog_ss_mod ._build_forecast_model (
927
+ time_index = time_index ,
928
+ t0 = t0 ,
929
+ forecast_index = forecast_index ,
930
+ scenario = scenario ,
931
+ filter_output = "predicted" ,
932
+ mvn_method = "svd" ,
933
+ )
934
+
935
+ frozen_shared_inputs = [
936
+ inpt
937
+ for inpt in graph_inputs ([test_forecast_model .x0_slice , test_forecast_model .P0_slice ])
938
+ if isinstance (inpt , SharedVariable )
939
+ and not isinstance (inpt .get_value (), np .random .Generator )
940
+ ]
941
+
942
+ assert (
943
+ len (frozen_shared_inputs ) == 0
944
+ ) # check there are no non-random generator SharedVariables in the frozen inputs
945
+
946
+ unfrozen_shared_inputs = [
947
+ inpt
948
+ for inpt in graph_inputs ([test_forecast_model .forecast_combined ])
949
+ if isinstance (inpt , SharedVariable )
950
+ and not isinstance (inpt .get_value (), np .random .Generator )
951
+ ]
952
+
953
+ # Check that there is one (in this case) unfrozen shared input and it corresponds to the exogenous data
954
+ assert len (unfrozen_shared_inputs ) == 1
955
+ assert unfrozen_shared_inputs [0 ].name == "data_exog"
956
+
957
+ data_after_build_forecast_model = {d .name : d .get_value () for d in test_forecast_model .data_vars }
958
+
959
+ with test_forecast_model :
960
+ dummy_obs_data = np .zeros ((len (forecast_index ), exog_ss_mod .k_endog ))
961
+ pm .set_data (
962
+ {"data_exog" : scenario } | {"data" : dummy_obs_data },
963
+ coords = {"data_time" : np .arange (len (forecast_index ))},
964
+ )
965
+ idata_forecast = pm .sample_posterior_predictive (
966
+ idata_exog , var_names = ["x0_slice" , "P0_slice" ]
967
+ )
968
+
969
+ np .testing .assert_allclose (
970
+ unfrozen_shared_inputs [0 ].get_value (), scenario ["x1" ].values .reshape ((- 1 , 1 ))
971
+ ) # ensure the replaced data matches the exogenous data
972
+
973
+ for k in data_before_build_forecast_model .keys ():
974
+ assert ( # check that the data needed to init the forecasts doesn't change
975
+ data_before_build_forecast_model [k ].mean () == data_after_build_forecast_model [k ].mean ()
976
+ )
977
+
978
+ # Check that the frozen states and covariances correctly match the sliced index
979
+ np .testing .assert_allclose (
980
+ idata_exog .posterior ["predicted_covariance" ].sel (time = t0 ).mean (("chain" , "draw" )).values ,
981
+ idata_forecast .posterior_predictive ["P0_slice" ].mean (("chain" , "draw" )).values ,
982
+ )
983
+ np .testing .assert_allclose (
984
+ idata_exog .posterior ["predicted_state" ].sel (time = t0 ).mean (("chain" , "draw" )).values ,
985
+ idata_forecast .posterior_predictive ["x0_slice" ].mean (("chain" , "draw" )).values ,
986
+ )
987
+
988
+
898
989
@pytest .mark .filterwarnings ("ignore:Provided data contains missing values" )
899
990
@pytest .mark .filterwarnings ("ignore:The RandomType SharedVariables" )
900
991
@pytest .mark .filterwarnings ("ignore:No time index found on the supplied data." )
0 commit comments