Skip to content

Commit 9c7a6fb

Browse files
Check model coords for unknown shapes when building predictive models (#413)
1 parent dcc353c commit 9c7a6fb

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

pymc_extras/statespace/core/statespace.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -983,10 +983,31 @@ def _build_dummy_graph(self) -> None:
983983
list[pm.Flat]
984984
A list of pm.Flat variables representing all parameters estimated by the model.
985985
"""
986+
987+
def infer_variable_shape(name):
988+
shape = self._name_to_variable[name].type.shape
989+
if not any(dim is None for dim in shape):
990+
return shape
991+
992+
dim_names = self._fit_dims.get(name, None)
993+
if dim_names is None:
994+
raise ValueError(
995+
f"Could not infer shape for {name}, because it was not given coords during model"
996+
f"fitting"
997+
)
998+
999+
shape_from_coords = tuple([len(self._fit_coords[dim]) for dim in dim_names])
1000+
return tuple(
1001+
[
1002+
shape[i] if shape[i] is not None else shape_from_coords[i]
1003+
for i in range(len(shape))
1004+
]
1005+
)
1006+
9861007
for name in self.param_names:
9871008
pm.Flat(
9881009
name,
989-
shape=self._name_to_variable[name].type.shape,
1010+
shape=infer_variable_shape(name),
9901011
dims=self._fit_dims.get(name, None),
9911012
)
9921013

tests/statespace/test_statespace.py

+54
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Sequence
12
from functools import partial
23

34
import numpy as np
@@ -349,6 +350,59 @@ def test_sampling_methods(group, kind, ss_mod, idata, rng):
349350
assert not np.any(np.isnan(test_idata[f"{group}_{output}"].values))
350351

351352

353+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values")
354+
def test_sample_conditional_with_time_varying():
355+
class TVCovariance(PyMCStateSpace):
356+
def __init__(self):
357+
super().__init__(k_states=1, k_endog=1, k_posdef=1)
358+
359+
def make_symbolic_graph(self) -> None:
360+
self.ssm["transition", 0, 0] = 1.0
361+
362+
self.ssm["design", 0, 0] = 1.0
363+
364+
sigma_cov = self.make_and_register_variable("sigma_cov", (None,))
365+
self.ssm["state_cov"] = sigma_cov[:, None, None] ** 2
366+
367+
@property
368+
def param_names(self) -> list[str]:
369+
return ["sigma_cov"]
370+
371+
@property
372+
def coords(self) -> dict[str, Sequence[str]]:
373+
return make_default_coords(self)
374+
375+
@property
376+
def state_names(self) -> list[str]:
377+
return ["level"]
378+
379+
@property
380+
def observed_states(self) -> list[str]:
381+
return ["level"]
382+
383+
@property
384+
def shock_names(self) -> list[str]:
385+
return ["level"]
386+
387+
ss_mod = TVCovariance()
388+
empty_data = pd.DataFrame(
389+
np.nan, index=pd.date_range("2020-01-01", periods=100, freq="D"), columns=["data"]
390+
)
391+
392+
coords = ss_mod.coords
393+
coords["time"] = empty_data.index
394+
with pm.Model(coords=coords) as mod:
395+
log_sigma_cov = pm.Normal("log_sigma_cov", mu=0, sigma=0.1, dims=["time"])
396+
pm.Deterministic("sigma_cov", pm.math.exp(log_sigma_cov.cumsum()), dims=["time"])
397+
398+
ss_mod.build_statespace_graph(data=empty_data)
399+
400+
prior = pm.sample_prior_predictive(10)
401+
402+
ss_mod.sample_unconditional_prior(prior)
403+
ss_mod.sample_conditional_prior(prior)
404+
405+
352406
def _make_time_idx(mod, use_datetime_index=True):
353407
if use_datetime_index:
354408
mod._fit_coords["time"] = nile.index

0 commit comments

Comments
 (0)