@@ -133,12 +133,19 @@ def compute_z(x):
133
133
],
134
134
)
135
135
@pytest .mark .parametrize (
136
- "backend, gradient_backend" ,
137
- [("jax" , "jax" ), ("jax" , "pytensor" )],
136
+ "backend, gradient_backend, include_transformed " ,
137
+ [("jax" , "jax" , True ), ("jax" , "pytensor" , False )],
138
138
ids = str ,
139
139
)
140
140
def test_find_MAP (
141
- method , use_grad , use_hess , use_hessp , backend , gradient_backend : GradientBackend , rng
141
+ method ,
142
+ use_grad ,
143
+ use_hess ,
144
+ use_hessp ,
145
+ backend ,
146
+ gradient_backend : GradientBackend ,
147
+ include_transformed ,
148
+ rng ,
142
149
):
143
150
pytest .importorskip ("jax" )
144
151
@@ -154,12 +161,12 @@ def test_find_MAP(
154
161
use_hessp = use_hessp ,
155
162
progressbar = False ,
156
163
gradient_backend = gradient_backend ,
164
+ include_transformed = include_transformed ,
157
165
compile_kwargs = {"mode" : backend .upper ()},
158
166
maxiter = 5 ,
159
167
)
160
168
161
169
assert hasattr (idata , "posterior" )
162
- assert hasattr (idata , "unconstrained_posterior" )
163
170
assert hasattr (idata , "fit" )
164
171
assert hasattr (idata , "optimizer_result" )
165
172
assert hasattr (idata , "observed_data" )
@@ -169,9 +176,13 @@ def test_find_MAP(
169
176
assert posterior ["mu" ].shape == ()
170
177
assert posterior ["sigma" ].shape == ()
171
178
172
- unconstrained_posterior = idata .unconstrained_posterior .squeeze (["chain" , "draw" ])
173
- assert "sigma_log__" in unconstrained_posterior
174
- assert unconstrained_posterior ["sigma_log__" ].shape == ()
179
+ if include_transformed :
180
+ assert hasattr (idata , "unconstrained_posterior" )
181
+ unconstrained_posterior = idata .unconstrained_posterior .squeeze (["chain" , "draw" ])
182
+ assert "sigma_log__" in unconstrained_posterior
183
+ assert unconstrained_posterior ["sigma_log__" ].shape == ()
184
+ else :
185
+ assert not hasattr (idata , "unconstrained_posterior" )
175
186
176
187
177
188
@pytest .mark .parametrize (
0 commit comments