Skip to content

Commit 30ffa95

Browse files
committed
Add support for multiple BART random variables per model.
1 parent b831618 commit 30ffa95

File tree

3 files changed

+94
-10
lines changed

3 files changed

+94
-10
lines changed

pymc_bart/bart.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pytensor.tensor.variable import TensorVariable
3030

3131
from .split_rules import SplitRule
32-
from .tree import Tree
3332
from .utils import TensorLike, _sample_posterior
3433

3534
__all__ = ["BART"]
@@ -42,7 +41,6 @@ class BARTRV(RandomVariable):
4241
signature = "(m,n),(m),(),(),() -> (m)"
4342
dtype: str = "floatX"
4443
_print_name: tuple[str, str] = ("BART", "\\operatorname{BART}")
45-
all_trees = list[list[list[Tree]]]
4644

4745
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
4846
idx = dist_params[0].ndim - 2
@@ -55,7 +53,7 @@ def rng_fn( # pylint: disable=W0237
5553
if not size:
5654
size = None
5755

58-
if not cls.all_trees:
56+
if not hasattr(cls, "all_trees") or not cls.all_trees:
5957
if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
6058
Y = cls.Y.eval()
6159
else:
@@ -142,8 +140,9 @@ def __new__(
142140
"Options linear and mix are experimental and still not well tested\n"
143141
+ "Use with caution."
144142
)
143+
# Create a unique manager list for each BART instance
145144
manager = Manager()
146-
cls.all_trees = manager.list()
145+
instance_all_trees = manager.list()
147146

148147
X, Y = preprocess_xy(X, Y)
149148

@@ -154,7 +153,7 @@ def __new__(
154153
(BARTRV,),
155154
{
156155
"name": "BART",
157-
"all_trees": cls.all_trees,
156+
"all_trees": instance_all_trees, # Instance-specific tree storage
158157
"inplace": False,
159158
"initval": Y.mean(),
160159
"X": X,

pymc_bart/pgbart.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def __init__( # noqa: PLR0912, PLR0915
130130
model: Optional[Model] = None,
131131
initial_point: PointType | None = None,
132132
compile_kwargs: dict | None = None,
133+
**kwargs, # Accept additional kwargs for compound sampling
133134
) -> None:
134135
model = modelcontext(model)
135136
if initial_point is None:
@@ -143,7 +144,24 @@ def __init__( # noqa: PLR0912, PLR0915
143144
if vars is None:
144145
raise ValueError("Unable to find variables to sample")
145146

146-
value_bart = vars[0]
147+
# Filter to only BART variables
148+
bart_vars = []
149+
for var in vars:
150+
rv = model.values_to_rvs.get(var)
151+
if rv is not None and isinstance(rv.owner.op, BARTRV):
152+
bart_vars.append(var)
153+
154+
if not bart_vars:
155+
raise ValueError("No BART variables found in the provided variables")
156+
157+
if len(bart_vars) > 1:
158+
raise ValueError(
159+
"PGBART can only handle one BART variable at a time. "
160+
"For multiple BART variables, PyMC will automatically create "
161+
"separate PGBART samplers for each variable."
162+
)
163+
164+
value_bart = bart_vars[0]
147165
self.bart = model.values_to_rvs[value_bart].owner.op
148166

149167
if isinstance(self.bart.X, Variable):
@@ -227,15 +245,15 @@ def __init__( # noqa: PLR0912, PLR0915
227245

228246
self.num_particles = num_particles
229247
self.indices = list(range(1, num_particles))
230-
shared = make_shared_replacements(initial_point, vars, model)
231-
self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared)
248+
shared = make_shared_replacements(initial_point, [value_bart], model)
249+
self.likelihood_logp = logp(initial_point, [model.datalogp], [value_bart], shared)
232250
self.all_particles = [
233251
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
234252
]
235253
self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles])
236254
self.lower = 0
237255
self.iter = 0
238-
super().__init__(vars, shared)
256+
super().__init__([value_bart], shared)
239257

240258
def astep(self, _):
241259
variable_inclusion = np.zeros(self.num_variates, dtype="int")
@@ -346,7 +364,7 @@ def resample(
346364
new_particles.append(particles[idx].copy())
347365
else:
348366
new_particles.append(particles[idx])
349-
seen.append(idx)
367+
seen.append(int(idx))
350368

351369
particles[1:] = new_particles
352370

tests/test_bart.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,3 +256,70 @@ def test_categorical_model(separate_trees, split_rule):
256256
# Fit should be good enough so right category is selected over 50% of time
257257
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
258258
assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3)
259+
260+
261+
def test_multiple_bart_variables():
    """Check that two BART variables can coexist and be sampled in one model.

    Two BART random variables are built on covariate matrices with a different
    number of columns, summed into the mean of a Normal likelihood, and sampled
    with the step methods that ``pm.sample`` assigns when no ``step`` is given.
    """
    # A seeded generator keeps the synthetic data identical between runs; the
    # sampler itself is seeded separately through ``random_seed`` below.
    rng = np.random.default_rng(3415)
    n_obs = 50

    X1 = rng.normal(0, 1, size=(n_obs, 2))
    X2 = rng.normal(0, 1, size=(n_obs, 3))
    Y = rng.normal(0, 1, size=n_obs)

    # Responses handed to each BART constructor: noisy functions of that
    # variable's own covariates.
    Y1 = X1[:, 0] + rng.normal(0, 0.1, size=n_obs)
    Y2 = X2[:, 0] + X2[:, 1] + rng.normal(0, 0.1, size=n_obs)

    with pm.Model():
        # Two separate BART variables with different covariates.
        mu1 = pmb.BART("mu1", X1, Y1, m=5)
        mu2 = pmb.BART("mu2", X2, Y2, m=5)

        # Combined likelihood; registering it in the model is the only effect
        # needed, so the returned variable is not bound to a name.
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("y", mu1 + mu2, sigma, observed=Y)

        # No explicit ``step``: step-method assignment is left to PyMC.
        idata = pm.sample(tune=50, draws=50, chains=1, random_seed=3415)

    # Each BART op must expose its own tree collection ...
    assert hasattr(mu1.owner.op, "all_trees")
    assert hasattr(mu2.owner.op, "all_trees")

    # ... and the two collections must be distinct objects, not a shared one.
    assert mu1.owner.op.all_trees is not mu2.owner.op.all_trees

    # Posterior shape is (chains, draws, observations).
    assert idata.posterior["mu1"].shape == (1, 50, n_obs)
    assert idata.posterior["mu2"].shape == (1, 50, n_obs)
293+
294+
295+
def test_multiple_bart_variables_manual_step():
    """Check that two BART variables sample with hand-built PGBART steps.

    One ``pmb.PGBART`` step is created per BART variable and both are passed
    to ``pm.sample`` through ``step``; the posterior must then contain draws
    of the expected shape for both variables.
    """
    # A seeded generator keeps the synthetic data identical between runs; the
    # sampler itself is seeded separately through ``random_seed`` below.
    rng = np.random.default_rng(3415)
    n_obs = 30

    X1 = rng.normal(0, 1, size=(n_obs, 2))
    X2 = rng.normal(0, 1, size=(n_obs, 2))
    Y = rng.normal(0, 1, size=n_obs)

    # Simple responses handed to each BART constructor.
    Y1 = X1[:, 0] + rng.normal(0, 0.1, size=n_obs)
    Y2 = X2[:, 1] + rng.normal(0, 0.1, size=n_obs)

    with pm.Model():
        # Two separate BART variables.
        mu1 = pmb.BART("mu1", X1, Y1, m=3)
        mu2 = pmb.BART("mu2", X2, Y2, m=3)

        # Non-BART variable plus the likelihood; the likelihood only needs to
        # be registered in the model, so its return value is not bound.
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("y", mu1 + mu2, sigma, observed=Y)

        # One PGBART sampler per BART variable, created inside the model
        # context.
        step1 = pmb.PGBART([mu1], num_particles=5)
        step2 = pmb.PGBART([mu2], num_particles=5)

        idata = pm.sample(tune=20, draws=20, chains=1, step=[step1, step2], random_seed=3415)

    # Both variables must be present in the posterior ...
    assert "mu1" in idata.posterior
    assert "mu2" in idata.posterior

    # ... with shape (chains, draws, observations).
    assert idata.posterior["mu1"].shape == (1, 20, n_obs)
    assert idata.posterior["mu2"].shape == (1, 20, n_obs)

0 commit comments

Comments
 (0)