Skip to content

Commit fa95fb9

Browse files
committed
start a test for the hsgp and 'by'
1 parent 3aecb7f commit fa95fb9

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

bambi/model_components.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,6 @@ def predict_common(
292292
# Remove columns of X that are associated with HSGP contributions
293293
# All the slices _must be_ deleted at the same time. Otherwise the slice objects don't
294294
# reflect the right columns of X at the time they're used
295-
# TODO: test this
296295
if hsgp_slices:
297296
X = np.delete(X, np.r_[tuple(hsgp_slices)], axis=1)
298297

tests/test_hsgp.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,35 @@ def test_minimal_1d_predicts(data_1d_single_group):
300300
new_idata = model.predict(idata, data=new_data, kind="pps", inplace=False)
301301
assert new_idata.posterior_predictive["y"].dims == ("chain", "draw", "y_obs")
302302
assert new_idata.posterior_predictive["y"].to_numpy().shape == (2, 500, 10)
303+
304+
305+
def test_multiple_hsgp_and_by(data_1d_multiple_groups):
306+
rng = np.random.default_rng(1234)
307+
df = data_1d_multiple_groups.copy()
308+
df["fac2"] = rng.choice(["a", "b", "c"], size=df.shape[0])
309+
310+
formula = "y ~ 1 + x0 + hsgp(x1, by=fac, m=10, c=2) + hsgp(x1, by=fac2, m=10, c=2)"
311+
model = bmb.Model(
312+
formula=formula,
313+
data=df,
314+
categorical=["fac"],
315+
)
316+
idata = model.fit(tune=400, draws=200, target_accept=0.9)
317+
318+
bmb.interpret.plot_predictions(
319+
model,
320+
idata,
321+
conditional="x1",
322+
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
323+
);
324+
325+
bmb.interpret.plot_predictions(
326+
model,
327+
idata,
328+
conditional={
329+
"x1": np.linspace(0, 1, num=100),
330+
"fac2": ["a", "b", "c"]
331+
},
332+
legend=False,
333+
subplot_kwargs={"main": "x1", "group": "fac2", "panel": "fac2"},
334+
);

0 commit comments

Comments
 (0)