49 changes: 44 additions & 5 deletions pymc/logprob/mixture.py
@@ -66,6 +66,7 @@
from pytensor.tensor.type_other import NoneConst, NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable

from pymc.exceptions import NotConstantValueError
from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableOp,
@@ -82,6 +83,7 @@
)
from pymc.logprob.utils import (
check_potential_measurability,
dirac_delta,
filter_measurable_variables,
get_related_valued_nodes,
)
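
For context: dirac_delta (imported above) wraps a constant as a measurable point mass. A minimal hedged sketch of the semantics the rewrite below relies on, using the public DiracDelta distribution (assuming the usual convention that the logp is 0.0 at the point and -inf elsewhere):

import pymc as pm

# Hedged sketch, not part of the diff: a point mass at -1 has logp 0.0 where
# the value equals -1 and -inf everywhere else.
dd = pm.DiracDelta.dist(-1)
vv = dd.clone()
dd_logp = pm.logp(dd, vv)
print(dd_logp.eval({vv: -1}))  # 0.0
print(dd_logp.eval({vv: 0}))   # -inf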
@@ -414,21 +416,58 @@ def find_measurable_switch_mixture(fgraph, node):

switch_cond, *components = node.inputs

# Require at least one measurable component; otherwise there's no logprob to compute
measurable_components = filter_measurable_variables(components)
if not measurable_components:
return None

# Only allow non-measurable components if they are compile-time constants
measurable_ids = {id(c) for c in measurable_components}
non_measurable_components = [c for c in components if id(c) not in measurable_ids]
folded_constants: dict[int, TensorVariable] = {}
for comp in non_measurable_components:
if isinstance(comp, Constant):
folded_constants[id(comp)] = comp
continue
try:
(folded_comp,) = constant_fold([comp], raise_not_constant=True)
except NotConstantValueError:
return None
if not isinstance(folded_comp, TensorVariable):
folded_comp = pt.constant(folded_comp)
if not isinstance(folded_comp, Constant):
return None
folded_constants[id(comp)] = folded_comp
# Use a measurable component as the broadcasting reference for constant branches
bcast_ref = measurable_components[0]
new_components: list[TensorVariable] = []
for comp in components:
if id(comp) in measurable_ids:
new_components.append(comp)
else:
# Treat constant branches as point masses so we can compute a logp.
# Broadcasting is allowed for constants because it doesn't introduce dependence.
const_comp = folded_constants[id(comp)]
bcast_comp, _ = pt.broadcast_arrays(const_comp, bcast_ref)
new_components.append(dirac_delta(bcast_comp))

# We don't support broadcasting of components, as that yields dependent (identical) values.
# The current logp implementation assumes all component values are independent.
# Broadcasting of the switch condition is fine
out_bcast = node.outputs[0].type.broadcastable
if any(comp.type.broadcastable != out_bcast for comp in components):
return None

if set(filter_measurable_variables(components)) != set(components):
if any(
comp.type.broadcastable != out_bcast
for comp in new_components
if id(comp) in measurable_ids
):
return None

# Check that `switch_cond` is not potentially measurable
if check_potential_measurability([switch_cond]):
return None

return [measurable_switch_mixture(switch_cond, *components)]
return [measurable_switch_mixture(switch_cond, *new_components)]


@_logprob.register(MeasurableSwitchMixture)
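The user-facing effect of the mixture.py change, as a minimal hedged sketch (mirroring the new test further down; warn_rvs=False only silences the warning about random variables remaining in the logp graph):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

# Sketch: a switch whose constant branch previously blocked the rewrite now
# has a logp, with the constant branch treated as a point mass at -1.
cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))
mixed = pt.where(pt.arange(10) > 5, cat, -1)
vv = mixed.clone()
mixed_logp = pm.logp(mixed, vv, warn_rvs=False)
value = np.where(np.arange(10) > 5, 0, -1).astype(vv.dtype)
print(mixed_logp.eval({vv: value}))  # log(0.5) where the switch picks cat, 0.0 elsewhere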
26 changes: 24 additions & 2 deletions pymc/logprob/tensor.py
@@ -39,6 +39,7 @@

from numpy.lib.array_utils import normalize_axis_index
from pytensor import tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import TensorVariable
@@ -65,6 +66,7 @@
)
from pymc.logprob.utils import (
check_potential_measurability,
dirac_delta,
filter_measurable_variables,
get_related_valued_nodes,
replace_rvs_by_values,
@@ -162,9 +164,29 @@ def find_measurable_stacks(fgraph, node) -> list[TensorVariable] | None:
else:
base_vars = node.inputs

if not all(check_potential_measurability([base_var]) for base_var in base_vars):
# Allow mixing potentially measurable inputs with compile-time constants.
new_base_vars: list[TensorVariable] = []
has_measurable = False
for base_var in base_vars:
if check_potential_measurability([base_var]):
has_measurable = True
new_base_vars.append(base_var)
else:
if isinstance(base_var, Constant):
folded_var = base_var
else:
try:
(folded_var,) = constant_fold([base_var], raise_not_constant=True)
except NotConstantValueError:
return None
if not isinstance(folded_var, TensorVariable):
folded_var = pt.constant(folded_var)
if not isinstance(folded_var, Constant):
return None
new_base_vars.append(dirac_delta(folded_var))
if not has_measurable:
return None

base_vars = new_base_vars
base_vars = assume_valued_outputs(base_vars)
if not all(var.owner and isinstance(var.owner.op, MeasurableOp) for var in base_vars):
return None
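The same pattern for stacks and joins: compile-time constant inputs become point masses instead of aborting the rewrite. A minimal hedged sketch of what this enables (values are illustrative):

import numpy as np
import pytensor.tensor as pt

from pymc.logprob.basic import logp

# Sketch: stacking a random variable with a plain constant now yields a logp;
# the constant entry contributes 0.0 when matched and -inf otherwise.
x = pt.random.normal(name="x")
y = pt.stack([x, pt.constant(0.0)])
y_vv = y.clone()
y_logp = logp(y, y_vv)
print(y_logp.eval({y_vv: np.array([0.5, 0.0])}))  # [normal logpdf at 0.5, 0.0]
print(y_logp.eval({y_vv: np.array([0.5, 1.0])}))  # second entry is -inf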
18 changes: 18 additions & 0 deletions tests/logprob/test_mixture.py
@@ -52,6 +52,8 @@
as_index_constant,
)

import pymc as pm

from pymc.logprob.abstract import MeasurableOp
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.mixture import MeasurableSwitchMixture, expand_indices
@@ -971,6 +973,22 @@ def test_switch_mixture_invalid_bcast():
assert not isinstance(fgraph.outputs[0].owner.inputs[0].owner.op, MeasurableOp)


def test_switch_mixture_constant_branch_broadcast_ok():
t = pt.arange(10)
cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))
cat_fixed_const = pt.where(t > 5, cat, -1)
cat_fixed_dirac = pt.where(t > 5, cat, pm.DiracDelta.dist(-1, shape=cat.shape))
vv_const = cat_fixed_const.clone()
vv_dirac = cat_fixed_dirac.clone()
logp_const = logp(cat_fixed_const, vv_const, warn_rvs=False)
logp_dirac = logp(cat_fixed_dirac, vv_dirac, warn_rvs=False)
test_value = np.where(np.arange(10) > 5, 0, -1).astype(vv_const.dtype)
np.testing.assert_allclose(
logp_const.eval({vv_const: test_value}),
logp_dirac.eval({vv_dirac: test_value.astype(vv_dirac.dtype)}),
)


def test_ifelse_mixture_one_component():
if_rv = pt.random.bernoulli(0.5, name="if")
scale_rv = pt.random.halfnormal(name="scale")
49 changes: 49 additions & 0 deletions tests/logprob/test_tensor.py
@@ -101,6 +101,30 @@ def test_measurable_make_vector():
assert np.isclose(make_vector_logp_eval.sum(), ref_logp_eval_eval)


def test_measurable_make_vector_with_constant_input():
base1_rv = pt.random.normal(name="base1")
base2_rv = pt.random.halfnormal(name="base2")
y_rv = pt.stack((base1_rv, pt.constant(0.0), base2_rv))
y_rv.name = "y"
base1_vv = base1_rv.clone()
base2_vv = base2_rv.clone()
y_vv = y_rv.clone()
ref_logp = conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv})
ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()])
y_logp = logp(y_rv, y_vv)
base1_testval = base1_rv.eval()
base2_testval = base2_rv.eval()
y_testval = np.stack((base1_testval, 0.0, base2_testval)).astype(y_vv.dtype)
ref_logp_eval = ref_logp_combined.eval({base1_vv: base1_testval, base2_vv: base2_testval})
y_logp_eval = y_logp.eval({y_vv: y_testval})
assert y_logp_eval.shape == y_testval.shape
assert np.isclose(y_logp_eval.sum(), ref_logp_eval)
y_testval_bad = y_testval.copy()
y_testval_bad[1] = 1.0
y_logp_eval_bad = y_logp.eval({y_vv: y_testval_bad})
assert y_logp_eval_bad[1] == -np.inf


@pytest.mark.parametrize("reverse", (False, True))
def test_measurable_make_vector_interdependent(reverse):
"""Test that we can obtain a proper graph when stacked RVs depend on each other"""
@@ -190,6 +214,31 @@ def test_measurable_join_interdependent(reverse):
)


def test_measurable_join_with_constant_input():
base1_rv = pt.random.normal(size=(2,), name="base1")
base2_rv = pt.random.exponential(size=(3,), name="base2")
const = pt.constant(np.array([0.0, 0.0, 0.0]))
y_rv = pt.join(0, base1_rv, const, base2_rv)
y_rv.name = "y"
base1_vv = base1_rv.clone()
base2_vv = base2_rv.clone()
y_vv = y_rv.clone()
ref_logp = conditional_logp({base1_rv: base1_vv, base2_rv: base2_vv})
ref_logp_combined = pt.sum([pt.sum(factor) for factor in ref_logp.values()])
y_logp = logp(y_rv, y_vv)
base1_testval = base1_rv.eval()
base2_testval = base2_rv.eval()
y_testval = np.concatenate([base1_testval, np.zeros(3), base2_testval]).astype(y_vv.dtype)
ref_logp_eval = ref_logp_combined.eval({base1_vv: base1_testval, base2_vv: base2_testval})
y_logp_eval = y_logp.eval({y_vv: y_testval})
assert y_logp_eval.shape == y_testval.shape
assert np.isclose(y_logp_eval.sum(), ref_logp_eval)
y_testval_bad = y_testval.copy()
y_testval_bad[2] = 1.0
y_logp_eval_bad = y_logp.eval({y_vv: y_testval_bad})
assert y_logp_eval_bad[2] == -np.inf


@pytest.mark.parametrize(
"size1, size2, axis, concatenate",
[