|
| 1 | +from collections.abc import Sequence |
1 | 2 | from functools import partial
|
2 | 3 |
|
3 | 4 | import numpy as np
|
@@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
|
349 | 350 | assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
|
350 | 351 |
|
351 | 352 |
|
| 353 | +@pytest.mark.filterwarnings("ignore:Provided data contains missing values") |
| 354 | +def test_sample_conditional_with_time_varying(): |
| 355 | + class TVCovariance(PyMCStateSpace): |
| 356 | + def __init__(self): |
| 357 | + super().__init__(k_states=1, k_endog=1, k_posdef=1) |
| 358 | + |
| 359 | + def make_symbolic_graph(self) -> None: |
| 360 | + self.ssm["transition", 0, 0] = 1.0 |
| 361 | + |
| 362 | + self.ssm["design", 0, 0] = 1.0 |
| 363 | + |
| 364 | + sigma_cov = self.make_and_register_variable("sigma_cov", (None,)) |
| 365 | + self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2 |
| 366 | + |
| 367 | + @property |
| 368 | + def param_names(self) -> list[str]: |
| 369 | + return ["sigma_cov"] |
| 370 | + |
| 371 | + @property |
| 372 | + def coords(self) -> dict[str, Sequence[str]]: |
| 373 | + return make_default_coords(self) |
| 374 | + |
| 375 | + @property |
| 376 | + def state_names(self) -> list[str]: |
| 377 | + return ["level"] |
| 378 | + |
| 379 | + @property |
| 380 | + def observed_states(self) -> list[str]: |
| 381 | + return ["level"] |
| 382 | + |
| 383 | + @property |
| 384 | + def shock_names(self) -> list[str]: |
| 385 | + return ["level"] |
| 386 | + |
| 387 | + ss_mod = TVCovariance() |
| 388 | + empty_data = pd.DataFrame( |
| 389 | + np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"] |
| 390 | + ) |
| 391 | + |
| 392 | + coords = ss_mod.coords |
| 393 | + coords["time"] = empty_data.index |
| 394 | + with pm.Model(coords=coords) as mod: |
| 395 | + log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"]) |
| 396 | + pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"]) |
| 397 | + |
| 398 | + ss_mod.build_statespace_graph(data=empty_data) |
| 399 | + |
| 400 | + prior = pm.sample_prior_predictive(10) |
| 401 | + |
| 402 | + ss_mod.sample_unconditional_prior(prior) |
| 403 | + ss_mod.sample_conditional_prior(prior) |
| 404 | + |
| 405 | + |
352 | 406 | def _make_time_idx(mod, use_datetime_index=True):
|
353 | 407 | if use_datetime_index:
|
354 | 408 | mod._fit_coords["time"] = nile.index
|
|
0 commit comments