
Commit 7ce4ac8

fix autoreparam because dims are no longer static (#363)
1 parent d50742d commit 7ce4ac8

File tree

pymc_experimental/model/transforms/autoreparam.py
tests/model/transforms/test_autoreparam.py

2 files changed: +45 -12 lines changed

pymc_experimental/model/transforms/autoreparam.py (+14 -8)
@@ -1,3 +1,4 @@
+import logging
 from dataclasses import dataclass
 from functools import singledispatch
 from typing import Dict, List, Optional, Sequence, Tuple, Union
@@ -8,7 +9,6 @@
 import pytensor.tensor as pt
 import scipy.special
 from pymc.distributions import SymbolicRandomVariable
-from pymc.exceptions import NotConstantValueError
 from pymc.logprob.transforms import Transform
 from pymc.model.fgraph import (
     ModelDeterministic,
@@ -19,10 +19,12 @@
     model_from_fgraph,
     model_named,
 )
-from pymc.pytensorf import constant_fold, toposort_replace
+from pymc.pytensorf import toposort_replace
 from pytensor.graph.basic import Apply, Variable
 from pytensor.tensor.random.op import RandomVariable
 
+_log = logging.getLogger("pmx")
+
 
 @dataclass
 class VIP:
@@ -174,15 +176,19 @@ def vip_reparam_node(
 ) -> Tuple[ModelDeterministic, ModelNamed]:
     if not isinstance(node.op, RandomVariable | SymbolicRandomVariable):
         raise TypeError("Op should be RandomVariable type")
-    rv = node.default_output()
-    try:
-        [rv_shape] = constant_fold([rv.shape])
-    except NotConstantValueError:
-        raise ValueError("Size should be static for autoreparametrization.")
+    # FIXME: This is wrong when size is None
+    _, size, *_ = node.inputs
+    eval_size = size.eval(mode="FAST_COMPILE")
+    if eval_size is not None:
+        rv_shape = tuple(eval_size)
+    else:
+        rv_shape = ()
+    lam_name = f"{name}::lam_logit__"
+    _log.debug(f"Creating {lam_name} with shape of {rv_shape}")
     logit_lam_ = pytensor.shared(
         np.zeros(rv_shape),
         shape=rv_shape,
-        name=f"{name}::lam_logit__",
+        name=lam_name,
     )
     logit_lam = model_named(logit_lam_, *dims)
     lam = pt.sigmoid(logit_lam)
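
Note: with mutable dims, rv.shape is no longer a compile-time constant, which is why the constant_fold call above had to go. A minimal sketch of the replacement logic in isolation (illustrative only; the toy model and variable names below are assumptions, not code from the repo):

import pymc as pm

# Build a model whose variable gets its shape from a (mutable) dim.
with pm.Model(coords=dict(a=range(5))) as m:
    g = pm.Normal("g", dims="a")

# Mirror the patched vip_reparam_node: read the size input off the
# RandomVariable node and evaluate it, instead of constant-folding rv.shape.
node = g.owner
_, size, *_ = node.inputs
eval_size = size.eval(mode="FAST_COMPILE")
rv_shape = tuple(eval_size) if eval_size is not None else ()
print(rv_shape)  # expected (5,) here; () when size is None (the FIXME case)

When a distribution gets its shape from parameter broadcasting alone, size is None and rv_shape collapses to (), the case flagged by the FIXME and pinned down by the xfail tests below.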

tests/model/transforms/test_autoreparam.py (+31 -4)
@@ -7,20 +7,21 @@
 
 @pytest.fixture
 def model_c():
-    with pm.Model() as mod:
+    # TODO: Restructure tests so they check one dist at a time
+    with pm.Model(coords=dict(a=range(5))) as mod:
         m = pm.Normal("m")
         s = pm.LogNormal("s")
-        pm.Normal("g", m, s, shape=5)
+        pm.Normal("g", m, s, dims="a")
         pm.Exponential("e", scale=s, shape=7)
     return mod
 
 
 @pytest.fixture
 def model_nc():
-    with pm.Model() as mod:
+    with pm.Model(coords=dict(a=range(5))) as mod:
         m = pm.Normal("m")
         s = pm.LogNormal("s")
-        pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
+        pm.Deterministic("g", pm.Normal("z", dims="a") * s + m)
         pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
     return mod
 
@@ -102,3 +103,29 @@ def test_set_truncate(model_c: pm.Model):
     vip.truncate_lambda(g=0.2)
     np.testing.assert_allclose(vip.get_lambda()["g"], 1)
     np.testing.assert_allclose(vip.get_lambda()["m"], 0.9)
+
+
+@pytest.mark.xfail(reason="FIX shape computation for lambda")
+def test_lambda_shape():
+    with pm.Model(coords=dict(a=[1, 2])) as model:
+        b1 = pm.Normal("b1", dims="a")
+        b2 = pm.Normal("b2", shape=2)
+        b3 = pm.Normal("b3", size=2)
+        b4 = pm.Normal("b4", np.asarray([1, 2]))
+    model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
+    lams = vip.get_lambda()
+    for v in ["b1", "b2", "b3", "b4"]:
+        assert lams[v].shape == (2,), v
+
+
+@pytest.mark.xfail(reason="FIX shape computation for lambda")
+def test_lambda_shape_transformed_1d():
+    with pm.Model(coords=dict(a=[1, 2])) as model:
+        b1 = pm.Exponential("b1", 1, dims="a")
+        b2 = pm.Exponential("b2", 1, shape=2)
+        b3 = pm.Exponential("b3", 1, size=2)
+        b4 = pm.Exponential("b4", np.asarray([1, 2]))
+    model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"])
+    lams = vip.get_lambda()
+    for v in ["b1", "b2", "b3", "b4"]:
+        assert lams[v].shape == (2,), v
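
A usage sketch of the round-trip these tests exercise (assumptions: the import path below matches the patched module, and a dims-backed variable of length 5; not code from the repo):

import pymc as pm

# Assumed import path (the module patched in this commit); adjust if the
# package re-exports vip_reparametrize elsewhere.
from pymc_experimental.model.transforms.autoreparam import vip_reparametrize

with pm.Model(coords=dict(a=range(5))) as model:
    m = pm.Normal("m")
    s = pm.LogNormal("s")
    pm.Normal("g", m, s, dims="a")

# Round-trip: a reparametrized model plus a VIP handle for the lambdas.
model_v, vip = vip_reparametrize(model, ["g"])
lams = vip.get_lambda()
print(lams["g"].shape)  # expected (5,): lambda should match the rv shape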
