|
7 | 7 |
|
8 | 8 | @pytest.fixture
|
9 | 9 | def model_c():
|
10 |
| - with pm.Model() as mod: |
| 10 | + # TODO: Restructure tests so they check one dist at a time |
| 11 | + with pm.Model(coords=dict(a=range(5))) as mod: |
11 | 12 | m = pm.Normal("m")
|
12 | 13 | s = pm.LogNormal("s")
|
13 |
| - pm.Normal("g", m, s, shape=5) |
| 14 | + pm.Normal("g", m, s, dims="a") |
14 | 15 | pm.Exponential("e", scale=s, shape=7)
|
15 | 16 | return mod
|
16 | 17 |
|
17 | 18 |
|
18 | 19 | @pytest.fixture
|
19 | 20 | def model_nc():
|
20 |
| - with pm.Model() as mod: |
| 21 | + with pm.Model(coords=dict(a=range(5))) as mod: |
21 | 22 | m = pm.Normal("m")
|
22 | 23 | s = pm.LogNormal("s")
|
23 |
| - pm.Deterministic("g", pm.Normal("z", shape=5) * s + m) |
| 24 | + pm.Deterministic("g", pm.Normal("z", dims="a") * s + m) |
24 | 25 | pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
|
25 | 26 | return mod
|
26 | 27 |
|
@@ -102,3 +103,29 @@ def test_set_truncate(model_c: pm.Model):
|
102 | 103 | vip.truncate_lambda(g=0.2)
|
103 | 104 | np.testing.assert_allclose(vip.get_lambda()["g"], 1)
|
104 | 105 | np.testing.assert_allclose(vip.get_lambda()["m"], 0.9)
|
| 106 | + |
| 107 | + |
| 108 | +@pytest.mark.xfail(reason="FIX shape computation for lambda") |
| 109 | +def test_lambda_shape(): |
| 110 | + with pm.Model(coords=dict(a=[1, 2])) as model: |
| 111 | + b1 = pm.Normal("b1", dims="a") |
| 112 | + b2 = pm.Normal("b2", shape=2) |
| 113 | + b3 = pm.Normal("b3", size=2) |
| 114 | + b4 = pm.Normal("b4", np.asarray([1, 2])) |
| 115 | + model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"]) |
| 116 | + lams = vip.get_lambda() |
| 117 | + for v in ["b1", "b2", "b3", "b4"]: |
| 118 | + assert lams[v].shape == (2,), v |
| 119 | + |
| 120 | + |
| 121 | +@pytest.mark.xfail(reason="FIX shape computation for lambda") |
| 122 | +def test_lambda_shape_transformed_1d(): |
| 123 | + with pm.Model(coords=dict(a=[1, 2])) as model: |
| 124 | + b1 = pm.Exponential("b1", 1, dims="a") |
| 125 | + b2 = pm.Exponential("b2", 1, shape=2) |
| 126 | + b3 = pm.Exponential("b3", 1, size=2) |
| 127 | + b4 = pm.Exponential("b4", np.asarray([1, 2])) |
| 128 | + model_v, vip = vip_reparametrize(model, ["b1", "b2", "b3", "b4"]) |
| 129 | + lams = vip.get_lambda() |
| 130 | + for v in ["b1", "b2", "b3", "b4"]: |
| 131 | + assert lams[v].shape == (2,), v |
0 commit comments