Skip to content

Add support for multiple BART random variables per model. #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions pymc_bart/bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pytensor.tensor.variable import TensorVariable

from .split_rules import SplitRule
from .tree import Tree
from .utils import TensorLike, _sample_posterior

__all__ = ["BART"]
Expand All @@ -42,7 +41,6 @@ class BARTRV(RandomVariable):
signature = "(m,n),(m),(),(),() -> (m)"
dtype: str = "floatX"
_print_name: tuple[str, str] = ("BART", "\\operatorname{BART}")
all_trees = list[list[list[Tree]]]

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): # pylint: disable=arguments-renamed
idx = dist_params[0].ndim - 2
Expand All @@ -55,7 +53,7 @@ def rng_fn( # pylint: disable=W0237
if not size:
size = None

if not cls.all_trees:
if not hasattr(cls, "all_trees") or not cls.all_trees:
if isinstance(cls.Y, (TensorSharedVariable, TensorVariable)):
Y = cls.Y.eval()
else:
Expand Down Expand Up @@ -142,8 +140,9 @@ def __new__(
"Options linear and mix are experimental and still not well tested\n"
+ "Use with caution."
)
# Create a unique manager list for each BART instance
manager = Manager()
cls.all_trees = manager.list()
instance_all_trees = manager.list()

X, Y = preprocess_xy(X, Y)

Expand All @@ -154,7 +153,7 @@ def __new__(
(BARTRV,),
{
"name": "BART",
"all_trees": cls.all_trees,
"all_trees": instance_all_trees, # Instance-specific tree storage
"inplace": False,
"initval": Y.mean(),
"X": X,
Expand Down
33 changes: 29 additions & 4 deletions pymc_bart/pgbart.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__( # noqa: PLR0912, PLR0915
model: Optional[Model] = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
**kwargs, # Accept additional kwargs for compound sampling
) -> None:
model = modelcontext(model)
if initial_point is None:
Expand All @@ -143,7 +144,24 @@ def __init__( # noqa: PLR0912, PLR0915
if vars is None:
raise ValueError("Unable to find variables to sample")

value_bart = vars[0]
# Filter to only BART variables
bart_vars = []
for var in vars:
rv = model.values_to_rvs.get(var)
if rv is not None and isinstance(rv.owner.op, BARTRV):
bart_vars.append(var)

if not bart_vars:
raise ValueError("No BART variables found in the provided variables")

if len(bart_vars) > 1:
raise ValueError(
"PGBART can only handle one BART variable at a time. "
"For multiple BART variables, PyMC will automatically create "
"separate PGBART samplers for each variable."
)

value_bart = bart_vars[0]
self.bart = model.values_to_rvs[value_bart].owner.op

if isinstance(self.bart.X, Variable):
Expand Down Expand Up @@ -227,15 +245,15 @@ def __init__( # noqa: PLR0912, PLR0915

self.num_particles = num_particles
self.indices = list(range(1, num_particles))
shared = make_shared_replacements(initial_point, vars, model)
self.likelihood_logp = logp(initial_point, [model.datalogp], vars, shared)
shared = make_shared_replacements(initial_point, [value_bart], model)
self.likelihood_logp = logp(initial_point, [model.datalogp], [value_bart], shared)
self.all_particles = [
[ParticleTree(self.a_tree) for _ in range(self.m)] for _ in range(self.trees_shape)
]
self.all_trees = np.array([[p.tree for p in pl] for pl in self.all_particles])
self.lower = 0
self.iter = 0
super().__init__(vars, shared)
super().__init__([value_bart], shared)

def astep(self, _):
variable_inclusion = np.zeros(self.num_variates, dtype="int")
Expand Down Expand Up @@ -408,6 +426,13 @@ def competence(var: pm.Distribution, has_grad: bool) -> Competence:
return Competence.IDEAL
return Competence.INCOMPATIBLE

@staticmethod
def _make_update_stats_functions():
    """Build the stats-update callables for this step method.

    Returns
    -------
    tuple
        A one-element tuple holding a function that takes a step-stats
        mapping and returns a new dict with only its
        ``"variable_inclusion"`` and ``"tune"`` entries.
    """
    kept_keys = ("variable_inclusion", "tune")

    def update_stats(step_stats):
        # Copy over just the entries of interest, in a fixed order.
        selected = {}
        for name in kept_keys:
            selected[name] = step_stats[name]
        return selected

    return (update_stats,)


class RunningSd:
"""Welford's online algorithm for computing the variance/standard deviation"""
Expand Down
67 changes: 67 additions & 0 deletions tests/test_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,3 +256,70 @@ def test_categorical_model(separate_trees, split_rule):
# Fit should be good enough so right category is selected over 50% of time
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3)


def test_multiple_bart_variables():
    """Two BART variables in one model keep separate trees and are both sampled."""
    n_obs = 50
    n_draws = 50

    # Covariate matrices with a different number of columns per BART variable.
    X1 = np.random.normal(0, 1, size=(n_obs, 2))
    X2 = np.random.normal(0, 1, size=(n_obs, 3))
    # Observed data passed to the likelihood.
    Y = np.random.normal(0, 1, size=n_obs)

    # Responses handed to each BART variable: covariate columns plus small noise.
    noise1 = np.random.normal(0, 0.1, size=n_obs)
    Y1 = X1[:, 0] + noise1
    noise2 = np.random.normal(0, 0.1, size=n_obs)
    Y2 = X2[:, 0] + X2[:, 1] + noise2

    with pm.Model():
        # One BART variable per covariate matrix.
        mu1 = pmb.BART("mu1", X1, Y1, m=5)
        mu2 = pmb.BART("mu2", X2, Y2, m=5)

        # Likelihood on the sum of the two BART variables.
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("y", mu1 + mu2, sigma, observed=Y)

        # No explicit step: let pm.sample assign the step methods.
        idata = pm.sample(tune=50, draws=n_draws, chains=1, random_seed=3415)

    op1 = mu1.owner.op
    op2 = mu2.owner.op

    # Each BART op exposes its own tree collection ...
    assert hasattr(op1, "all_trees")
    assert hasattr(op2, "all_trees")

    # ... and the two collections are distinct objects.
    assert op1.all_trees is not op2.all_trees

    # Posterior draws exist for both variables with (chain, draw, obs) shape.
    expected_shape = (1, n_draws, n_obs)
    assert idata.posterior["mu1"].shape == expected_shape
    assert idata.posterior["mu2"].shape == expected_shape


def test_multiple_bart_variables_manual_step():
    """Two BART variables can each be sampled by an explicitly built PGBART step."""
    n_obs = 30
    n_draws = 20

    # Covariates for each BART variable and the observed data for the likelihood.
    X1 = np.random.normal(0, 1, size=(n_obs, 2))
    X2 = np.random.normal(0, 1, size=(n_obs, 2))
    Y = np.random.normal(0, 1, size=n_obs)

    # Responses handed to each BART variable: one covariate column plus small noise.
    noise1 = np.random.normal(0, 0.1, size=n_obs)
    Y1 = X1[:, 0] + noise1
    noise2 = np.random.normal(0, 0.1, size=n_obs)
    Y2 = X2[:, 1] + noise2

    with pm.Model():
        mu1 = pmb.BART("mu1", X1, Y1, m=3)
        mu2 = pmb.BART("mu2", X2, Y2, m=3)

        # A non-BART variable alongside the two BART variables.
        sigma = pm.HalfNormal("sigma", 1)
        pm.Normal("y", mu1 + mu2, sigma, observed=Y)

        # One PGBART step per BART variable, created in the same order as before.
        step1 = pmb.PGBART([mu1], num_particles=5)
        step2 = pmb.PGBART([mu2], num_particles=5)
        manual_steps = [step1, step2]

        idata = pm.sample(
            tune=20,
            draws=n_draws,
            chains=1,
            step=manual_steps,
            random_seed=3415,
        )

    posterior = idata.posterior

    # Both variables were sampled ...
    assert "mu1" in posterior
    assert "mu2" in posterior

    # ... with (chain, draw, obs) shaped draws.
    expected_shape = (1, n_draws, n_obs)
    assert posterior["mu1"].shape == expected_shape
    assert posterior["mu2"].shape == expected_shape
Loading