Skip to content

Commit 83953eb

Browse files
committed
Add Beta-Binomial conjugacy optimization
1 parent e051965 commit 83953eb

File tree

9 files changed

+365
-16
lines changed

9 files changed

+365
-16
lines changed

pymc_experimental/model/marginal/distributions.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pymc.logprob.abstract import MeasurableOp, _logprob
88
from pymc.logprob.basic import conditional_logp, logp
99
from pymc.pytensorf import constant_fold
10-
from pytensor import Variable
1110
from pytensor.compile.builders import OpFromGraph
1211
from pytensor.compile.mode import Mode
1312
from pytensor.graph import Op, vectorize_graph
@@ -17,6 +16,7 @@
1716
from pytensor.tensor import TensorVariable
1817

1918
from pymc_experimental.distributions import DiscreteMarkovChain
19+
from pymc_experimental.utils.ofg import inline_ofg_outputs
2020

2121

2222
class MarginalRV(OpFromGraph, MeasurableOp):
@@ -126,18 +126,6 @@ def align_logp_dims(dims: tuple[tuple[int, None]], logp: TensorVariable) -> Tens
126126
return logp.transpose(*dims_alignment)
127127

128128

129-
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
130-
"""Inline the inner graph (outputs) of an OpFromGraph Op.
131-
132-
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
133-
the inner graph.
134-
"""
135-
return clone_replace(
136-
op.inner_outputs,
137-
replace=tuple(zip(op.inner_inputs, inputs)),
138-
)
139-
140-
141129
DUMMY_ZERO = pt.constant(0, name="dummy_zero")
142130

143131

+1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# Add rewrites to the optimization DBs
2+
import pymc_experimental.sampling.optimizations.conjugacy
23
import pymc_experimental.sampling.optimizations.summary_stats

pymc_experimental/sampling/mcmc.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,16 @@ def opt_sample(
4747
y = pm.Binomial("y", n=10, p=p, observed=5)
4848
4949
idata = pmx.opt_sample(verbose=True)
50+
51+
# Applied optimization: beta_binomial_conjugacy 1x
52+
# ConjugateRVSampler: [p]
5053
"""
5154

5255
model = modelcontext(model)
5356
fgraph, _ = fgraph_from_model(model)
5457

5558
if rewriter is None:
56-
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats"]))
59+
rewriter = posterior_optimization_db.query(RewriteDatabaseQuery(include=["summary_stats", "conjugacy"]))
5760
_, _, rewrite_counters, *_ = rewriter.rewrite(fgraph)
5861

5962
if verbose:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
from typing import Sequence
2+
3+
from pymc import STEP_METHODS
4+
from pytensor.tensor.random.type import RandomGeneratorType
5+
6+
from pytensor.compile.builders import OpFromGraph
7+
8+
from pymc_experimental.sampling.mcmc import posterior_optimization_db
9+
from pymc_experimental.sampling.optimizations.conjugate_sampler import ConjugateRV, ConjugateRVSampler
10+
11+
STEP_METHODS.append(ConjugateRVSampler)
12+
13+
from pytensor.graph.fg import Output
14+
from pytensor.tensor.elemwise import DimShuffle
15+
from pymc.model.fgraph import model_free_rv, ModelValuedVar
16+
17+
18+
from pytensor.graph.basic import Variable
19+
from pytensor.graph.fg import FunctionGraph
20+
from pytensor.graph.rewriting.basic import node_rewriter
21+
from pymc.model.fgraph import ModelFreeRV
22+
from pymc.distributions import Beta, Binomial
23+
from pymc.pytensorf import collect_default_updates
24+
25+
26+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable:
27+
"""Return the Model dummy var that wraps the RV"""
28+
for client, _ in fgraph.clients[rv]:
29+
if isinstance(client.op, ModelValuedVar):
30+
return client.outputs[0]
31+
32+
33+
def get_dist_params(rv: Variable) -> tuple[Variable]:
34+
return rv.owner.op.dist_params(rv.owner)
35+
36+
37+
def rv_used_by(fgraph: FunctionGraph, rv: Variable, used_by_type: type, used_as_arg_idx: int | Sequence[int], strict: bool = True) -> list[Variable]:
38+
"""Return the RVs that use `rv` as an argument in an operation of type `used_by_type`.
39+
40+
RV may be used directly or broadcasted before being used.
41+
42+
Parameters
43+
----------
44+
fgraph : FunctionGraph
45+
The function graph containing the RVs
46+
rv : Variable
47+
The RV to check for uses.
48+
used_by_type : type
49+
The type of operation that may use the RV.
50+
used_as_arg_idx : int | Sequence[int]
51+
The index of the RV in the operation's inputs.
52+
strict : bool, default=True
53+
If True, return no results when the RV is used in an unrecognized way.
54+
55+
"""
56+
if isinstance(used_as_arg_idx, int):
57+
used_as_arg_idx = (used_as_arg_idx,)
58+
59+
clients = fgraph.clients
60+
used_by : list[Variable] = []
61+
for client, inp_idx in clients[rv]:
62+
if isinstance(client.op, Output):
63+
continue
64+
65+
if isinstance(client.op, used_by_type) and inp_idx in used_as_arg_idx:
66+
# RV is directly used by the RV type
67+
used_by.append(client.default_output())
68+
69+
elif isinstance(client.op, DimShuffle) and client.op.is_left_expand_dims:
70+
for sub_client, sub_inp_idx in clients[client.outputs[0]]:
71+
if isinstance(sub_client.op, used_by_type) and sub_inp_idx in used_as_arg_idx:
72+
# RV is broadcasted and then used by the RV type
73+
used_by.append(sub_client.default_output())
74+
elif strict:
75+
# Some other unrecognized use, bail out
76+
return []
77+
elif strict:
78+
# Some other unrecognized use, bail out
79+
return []
80+
81+
return used_by
82+
83+
84+
def wrap_rv_and_conjugate_rv(fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]) -> Variable:
85+
"""Wrap the RV and its conjugate posterior RV in a ConjugateRV node.
86+
87+
Also takes care of handling the random number generators used in the conjugate posterior.
88+
"""
89+
rngs, next_rngs = zip(*collect_default_updates(conjugate_rv, inputs=[rv, *inputs]).items())
90+
for rng in rngs:
91+
if rng not in fgraph.inputs:
92+
fgraph.add_input(rng)
93+
conjugate_op = ConjugateRV(inputs=[rv, *inputs, *rngs], outputs=[rv, conjugate_rv, *next_rngs])
94+
return conjugate_op(rv, *inputs, *rngs)[0]
95+
96+
97+
def create_untransformed_free_rv(fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]) -> Variable:
98+
"""Create a model FreeRV without transform."""
99+
transform = None
100+
value = rv.type(name=name)
101+
fgraph.add_input(value)
102+
free_rv = model_free_rv(rv, value, transform, *dims)
103+
free_rv.name = name
104+
return free_rv
105+
106+
107+
@node_rewriter(tracks=[ModelFreeRV])
108+
def beta_binomial_conjugacy(fgraph: FunctionGraph, node):
109+
"""This applies the equivalence (up to a normalizing constant) described in:
110+
111+
https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics
112+
"""
113+
[beta_free_rv] = node.outputs
114+
beta_rv, beta_value, *beta_dims = node.inputs
115+
116+
if not isinstance(beta_rv.owner.op, Beta):
117+
return None
118+
119+
p_arg_idx = 3 # inputs to Binomial are (rng, size, n, p)
120+
binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)
121+
122+
if len(binomial_rvs) != 1:
123+
# Question: Can we apply conjugacy when RV is used by more than one binomial?
124+
return None
125+
126+
[binomial_rv] = binomial_rvs
127+
128+
binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
129+
if binomial_model_var is None:
130+
return None
131+
132+
# We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
133+
a, b = get_dist_params(beta_rv)
134+
n, _ = get_dist_params(binomial_rv)
135+
136+
# Use value of y in new graph to avoid circularity
137+
y = binomial_model_var.owner.inputs[1]
138+
139+
conjugate_a = a + y
140+
conjugate_b = b + (n - y)
141+
extra_dims = range(binomial_rv.type.ndim - beta_rv.type.ndim)
142+
if extra_dims:
143+
conjugate_a = conjugate_a.sum(extra_dims)
144+
conjugate_b = conjugate_b.sum(extra_dims)
145+
conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b)
146+
147+
new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
148+
new_beta_free_rv = create_untransformed_free_rv(fgraph, new_beta_rv, beta_free_rv.name, beta_dims)
149+
return [new_beta_free_rv]
150+
151+
152+
posterior_optimization_db.register(
153+
beta_binomial_conjugacy.__name__,
154+
beta_binomial_conjugacy,
155+
"conjugacy"
156+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
3+
from pymc_experimental.utils.ofg import inline_ofg_outputs
4+
from pytensor.compile.builders import OpFromGraph
5+
from pymc.logprob.abstract import MeasurableOp, _logprob
6+
from pymc.distributions.distribution import _support_point
7+
from pymc.step_methods.compound import BlockedStep, StepMethodState, Competence
8+
from pymc.model.core import modelcontext
9+
from pymc.util import get_value_vars_from_user_vars
10+
from pymc.pytensorf import compile_pymc
11+
from pytensor import shared
12+
from pytensor.tensor.random.type import RandomGeneratorType
13+
from pytensor.link.jax.linker import JAXLinker
14+
from pymc.initial_point import PointType
15+
16+
class ConjugateRV(OpFromGraph, MeasurableOp):
17+
"""Wrapper for ConjugateRVs, that outputs the original RV and the conjugate posterior expression.
18+
19+
For partial step samplers to work, the logp and initial point correspond to the original RV
20+
while the variable itself is sampled by default by the `ConjugateRVSampler` by evaluating directly the
21+
conjugate posterior expression (i.e., taking forward random draws).
22+
"""
23+
24+
25+
@_logprob.register(ConjugateRV)
26+
def conjugate_rv_logp(op, values, rv, *params, **kwargs):
27+
# Logp is the same as the original RV
28+
return _logprob(rv.owner.op, values, *rv.owner.inputs)
29+
30+
31+
@_support_point.register(ConjugateRV)
32+
def conjugate_rv_support_point(op, conjugate_rv, rv, *params):
33+
# Support point is the same as the original RV
34+
return _support_point(rv.owner.op, rv, *rv.owner.inputs)
35+
36+
37+
class ConjugateRVSampler(BlockedStep):
38+
name = "conjugate_rv_sampler"
39+
_state_class = StepMethodState
40+
41+
def __init__(self, vars, model=None, rng=None, compile_kwargs: dict | None = None, **kwargs):
42+
if len(vars) != 1:
43+
raise ValueError("ConjugateRVSampler can only be assigned to one variable at a time")
44+
45+
model = modelcontext(model)
46+
[value] = get_value_vars_from_user_vars(vars, model=model)
47+
rv = model.values_to_rvs[value]
48+
self.vars = (value,)
49+
self.rv_name = value.name
50+
51+
if model.rvs_to_transforms[rv] is not None:
52+
raise ValueError("Variable assigned to ConjugateRVSampler cannot be transformed")
53+
54+
rv_and_posterior_rv_node = rv.owner
55+
op = rv_and_posterior_rv_node.op
56+
if not isinstance(op, ConjugateRV):
57+
raise ValueError("Variable must be a ConjugateRV")
58+
59+
# Replace RVs in inputs of rv_posterior_rv_node by the corresponding value variables
60+
value_inputs = model.replace_rvs_by_values(
61+
[rv_and_posterior_rv_node.outputs[1]],
62+
)[0].owner.inputs
63+
# Inline the ConjugateRV graph to only compile `posterior_rv`
64+
_, posterior_rv, *_ = inline_ofg_outputs(op, value_inputs)
65+
66+
if compile_kwargs is None:
67+
compile_kwargs = {}
68+
self.posterior_fn = compile_pymc(
69+
model.value_vars,
70+
posterior_rv,
71+
random_seed=rng,
72+
on_unused_input="ignore",
73+
**compile_kwargs,
74+
)
75+
self.posterior_fn.trust_input = True
76+
if isinstance(self.posterior_fn.maker.linker, JAXLinker):
77+
# Reseeding RVs in JAX backend requires a different logic, becuase the SharedVariables
78+
# used internally are not the ones that `function.get_shared()` returns.
79+
raise ValueError("ConjugateRVSampler is not compatible with JAX backend")
80+
81+
def set_rng(self, rng: np.random.Generator):
82+
# Copy the function and replace any shared RNGs
83+
# This is needed so that it can work correctly with multiple traces
84+
# This will be costly if set_rng is called too often!
85+
shared_rngs = [
86+
var for var in self.posterior_fn.get_shared() if isinstance(var.type, RandomGeneratorType)
87+
]
88+
n_shared_rngs = len(shared_rngs)
89+
swap = {
90+
old_shared_rng: shared(rng, borrow=True)
91+
for old_shared_rng, rng in zip(shared_rngs, rng.spawn(n_shared_rngs), strict=True)
92+
}
93+
self.posterior_fn = self.posterior_fn.copy(swap=swap)
94+
95+
def step(self, point: PointType) -> tuple[PointType, list]:
96+
new_point = point.copy()
97+
new_point[self.rv_name] = self.posterior_fn(**point)
98+
return new_point, []
99+
100+
@staticmethod
101+
def competence(var, has_grad):
102+
"""BinaryMetropolis is only suitable for Bernoulli and Categorical variables with k=2."""
103+
if isinstance(var.owner.op, ConjugateRV):
104+
return Competence.IDEAL
105+
106+
return Competence.INCOMPATIBLE

pymc_experimental/utils/ofg.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from pytensor.graph.basic import Variable
2+
from pytensor.graph.replace import clone_replace
3+
from pytensor.compile.builders import OpFromGraph
4+
from typing import Sequence
5+
6+
7+
def inline_ofg_outputs(op: OpFromGraph, inputs: Sequence[Variable]) -> tuple[Variable]:
8+
"""Inline the inner graph (outputs) of an OpFromGraph Op.
9+
10+
Whereas `OpFromGraph` "wraps" a graph inside a single Op, this function "unwraps"
11+
the inner graph.
12+
"""
13+
return clone_replace(
14+
op.inner_outputs,
15+
replace=tuple(zip(op.inner_inputs, inputs)),
16+
)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ addopts = [
88
]
99

1010
filterwarnings =[
11-
"error",
11+
# "error",
1212
# Raised by arviz when the model_builder class adds non-standard group names to InferenceData
1313
"ignore::UserWarning:arviz.data.inference_data",
1414

tests/sampling/mcmc/test_mcmc.py

+34-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import logging
2+
13
import numpy as np
24
from pymc.model.core import Model
3-
from pymc.distributions import Normal, HalfNormal
5+
from pymc.distributions import Normal, HalfNormal, Beta, Binomial
46
from pymc.sampling.mcmc import sample
57

68
from pymc_experimental import opt_sample
@@ -27,3 +29,34 @@ def test_sample_opt_summary_stats(capsys):
2729
np.testing.assert_allclose(idata.posterior["mu"].mean(), opt_idata.posterior["mu"].mean(), rtol=1e-3)
2830
np.testing.assert_allclose(idata.posterior["sigma"].mean(), opt_idata.posterior["sigma"].mean(), rtol=1e-2)
2931
assert idata.sample_stats.sampling_time > opt_idata.sample_stats.sampling_time
32+
33+
34+
def test_sample_opt_conjugate(caplog, capsys):
35+
caplog.set_level(logging.INFO, logger="pymc")
36+
37+
with Model() as m:
38+
p = Beta("p", 1, 1)
39+
y = Binomial("y", n=100, p=p, observed=99)
40+
41+
idata = opt_sample(tune=0, chains=4, draws=250, progressbar=False, compute_convergence_checks=False, random_seed=0, verbose=True)
42+
43+
captured_out = capsys.readouterr().out
44+
assert "Applied optimization: beta_binomial_conjugacy 1x" in captured_out
45+
46+
# Test it used ConjugateRVSampler
47+
assert "ConjugateRVSampler: [p]" in caplog.text
48+
49+
np.testing.assert_allclose(idata.posterior["p"].mean(), 100/102, atol=1e-3)
50+
np.testing.assert_allclose(idata.posterior["p"].std(), np.sqrt(100*2/(102**2 * 103)), atol=1e-3)
51+
52+
# Draws are different across chains
53+
assert (np.diff(idata.posterior["p"].isel(draw=0).values) > 0).all()
54+
55+
# Check draws respect random_seed
56+
with m:
57+
new_idata = opt_sample(tune=0, chains=4, draws=1, progressbar=False, compute_convergence_checks=False, random_seed=0)
58+
np.testing.assert_allclose(idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0))
59+
60+
with m:
61+
new_idata = opt_sample(tune=0, chains=4, draws=1, progressbar=False, compute_convergence_checks=False, random_seed=1)
62+
assert not np.allclose(idata.posterior["p"].isel(draw=0), new_idata.posterior["p"].isel(draw=0))

0 commit comments

Comments
 (0)