71 changes: 68 additions & 3 deletions pymc/logprob/scan.py
@@ -41,6 +41,7 @@
import numpy as np
import pytensor.tensor as pt

from pytensor import clone_replace
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scan.op import Scan
@@ -103,6 +104,38 @@ def convert_outer_out_to_in(
"""
output_scan_args = copy(input_scan_args)
inner_outs_to_new_inner_ins = {}
inner_var_replacements: dict[TensorVariable, TensorVariable] = {}

def ensure_inner_matches_outer_broadcastable(
inner_var: TensorVariable, outer_in_var: TensorVariable
) -> TensorVariable:
"""Return a nominal inner variable matching `outer_in_var` broadcastability.

`Scan.make_node` checks that outer sequence inputs and inner sequence
placeholders agree on static broadcastability (excluding the time axis).
Observed/value variables can introduce extra broadcastable axes (size-1
dimensions), so we may need to upgrade the inner placeholders.

We must not use `pt.specify_broadcastable` here, because that creates an
Apply node and the resulting variable is no longer a valid nominal inner
input (leading to MissingInputError when Scan builds its nominal graph).
"""
inner_var = inner_var_replacements.get(inner_var, inner_var)
if outer_in_var.type.ndim != inner_var.type.ndim + 1:
return inner_var
new_shape = list(inner_var.type.shape)
changed = False
for axis in range(inner_var.type.ndim):
if outer_in_var.type.broadcastable[axis + 1] and not inner_var.type.broadcastable[axis]:
new_shape[axis] = 1
changed = True
if not changed:
return inner_var
new_type = inner_var.type.clone(shape=tuple(new_shape))
new_inner = new_type()
new_inner.name = inner_var.name
inner_var_replacements[inner_var] = new_inner
return new_inner

    # Map inner-outputs to outer-outputs
    old_inner_outs_to_outer_outs = {}
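The helper above builds a fresh variable from a cloned type rather than calling pt.specify_broadcastable because, as its docstring notes, a nominal inner input to Scan must be a root variable with no owner. A minimal sketch of the distinction, using only the public pytensor API and illustrative variable names:

import pytensor.tensor as pt

x = pt.matrix("x")  # static shape (None, None)

# `specify_broadcastable` returns a *derived* variable: it is the output
# of an Apply node, so Scan cannot accept it as a nominal inner input.
y = pt.specify_broadcastable(x, 1)
assert y.owner is not None

# Cloning the type with a size-1 axis and instantiating it yields a fresh
# root variable (no owner), which is a valid nominal inner input.
new_type = x.type.clone(shape=(None, 1))
z = new_type("x")
assert z.owner is None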
@@ -141,10 +174,13 @@ def convert_outer_out_to_in(
        var_idx = inner_out_info.index

        # The old inner-output variable becomes a new inner-input
-       new_inner_in_var = old_inner_out_var.clone()
        new_inner_in_var = old_inner_out_var.type()
        if new_inner_in_var.name:
            new_inner_in_var.name = f"{new_inner_in_var.name}_vv"

        outer_in_var = new_outer_input_vars[oo_var]
        new_inner_in_var = ensure_inner_matches_outer_broadcastable(new_inner_in_var, outer_in_var)

        inner_outs_to_new_inner_ins[old_inner_out_var] = new_inner_in_var

        # We want to remove elements from both lists and tuples, because the
@@ -159,14 +195,21 @@ def remove(x, i):
            inner_in_mit_sot_var = cast(
                tuple[int, ...], tuple(output_scan_args.inner_in_mit_sot[var_idx])
            )
            inner_in_mit_sot_var = tuple(
                ensure_inner_matches_outer_broadcastable(v, outer_in_var)
                for v in inner_in_mit_sot_var
            )
            new_inner_in_seqs = (*inner_in_mit_sot_var, new_inner_in_var)
            new_inner_in_mit_sot = remove(output_scan_args.inner_in_mit_sot, var_idx)
            new_outer_in_mit_sot = remove(output_scan_args.outer_in_mit_sot, var_idx)
            new_inner_in_sit_sot = tuple(output_scan_args.inner_in_sit_sot)
            new_outer_in_sit_sot = tuple(output_scan_args.outer_in_sit_sot)
            add_nit_sot = True
        elif inner_out_info.name.endswith("sit_sot"):
-           new_inner_in_seqs = (output_scan_args.inner_in_sit_sot[var_idx], new_inner_in_var)
            prev_inner_in_var = ensure_inner_matches_outer_broadcastable(
                output_scan_args.inner_in_sit_sot[var_idx], outer_in_var
            )
            new_inner_in_seqs = (prev_inner_in_var, new_inner_in_var)
            new_inner_in_sit_sot = remove(output_scan_args.inner_in_sit_sot, var_idx)
            new_outer_in_sit_sot = remove(output_scan_args.outer_in_sit_sot, var_idx)
            new_inner_in_mit_sot = tuple(output_scan_args.inner_in_mit_sot)
@@ -251,6 +294,9 @@ def remove(x, i):
        + output_scan_args.inner_out_nit_sot
    )
    traced_outs = replace_rvs_by_values(traced_outs, rvs_to_values=remapped_io_to_ii)

    if inner_var_replacements:
        traced_outs = clone_replace(traced_outs, replace=inner_var_replacements)
    # Update output mappings
    n_mit_sot = len(output_scan_args.inner_out_mit_sot)
    output_scan_args.inner_out_mit_sot = traced_outs[:n_mit_sot]
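The clone_replace call above is what propagates the upgraded nominal inputs through the traced inner outputs. A minimal sketch of the mechanics on an illustrative graph (not the PR's scan graph):

import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
b = pt.vector("b")
out = (a * 2).sum()

# clone_replace rebuilds the graph with `a` swapped for `b` everywhere.
(new_out,) = pytensor.clone_replace([out], replace={a: b})
assert new_out is not out

f = pytensor.function([b], new_out)
assert f([1.0, 2.0, 3.0]) == 12.0  # (1 + 2 + 3) * 2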
@@ -345,7 +391,26 @@ def logprob_scan(op, values, *inputs, name=None, **kwargs):
        # We will replace it by Join(axis=0, initial_value, value)
        initval = get_initval_from_scan_tap_input(inp)
        idx = outer_rvs.index(out)
-       values[idx] = pt.join(0, initval, values[idx])

        # Observed/value variables can carry extra *static* broadcastability
        # information (e.g. axis length 1 -> broadcastable True). When that
        # happens, scan's internal consistency checks may fail during the
        # construction of the logprob scan.
        #
        # We make the initial value at least as broadcastable as the value
        # variable (excluding the time axis added by scan), which is safe
        # whenever that axis is statically known to be broadcastable.
        value = values[idx]
        if value.type.ndim == initval.type.ndim + 1:
            extra_bcast_axes = [
                axis
                for axis in range(initval.type.ndim)
                if value.type.broadcastable[axis + 1] and not initval.type.broadcastable[axis]
            ]
            if extra_bcast_axes:
                initval = pt.specify_broadcastable(initval, *extra_bcast_axes)

        values[idx] = pt.join(0, initval, value)

    value_map = dict(zip(outer_rvs, values))

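What the new branch buys: pt.specify_broadcastable stamps static size-1 information onto an axis so that the subsequent join can preserve it. A minimal sketch with assumed shapes mirroring the test below; here the initial value carries its tap axis, so both join operands share the same ndim:

import pytensor.tensor as pt

initval = pt.tensor("y0", shape=(1, None))  # (tap, item), item size unknown
value = pt.tensor("y", shape=(None, 1))     # (time, item), item known to be 1

# Copy static size-1 information from the value's non-time axes.
extra_bcast_axes = [
    axis + 1
    for axis in range(value.type.ndim - 1)
    if value.type.broadcastable[axis + 1] and not initval.type.broadcastable[axis + 1]
]
if extra_bcast_axes:
    initval = pt.specify_broadcastable(initval, *extra_bcast_axes)

joined = pt.join(0, initval, value)
assert joined.type.shape == (None, 1)  # the size-1 item axis survives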
27 changes: 27 additions & 0 deletions tests/distributions/test_custom.py
@@ -709,6 +709,33 @@ def normal_shifted(mu, size):
            expected_logp,
        )

    def test_scan_logprob_observed_broadcastable_axis(self):
        def GRW(y_init, size=None):
            def grw_step(y_tm1):
                y = Normal.dist(mu=y_tm1)
                return y, collect_default_updates([y])

            n_steps = 10 if rv_size_is_none(size) else size[0]
            y_hat, _updates = pytensor.scan(fn=grw_step, outputs_info=[y_init], n_steps=n_steps)
            return y_hat

        coords = {
            "date": range(10),
            "item": [1],
        }
        with Model(coords=coords) as m:
            y0 = Normal("y0", 0, 0.1, dims=["item"])
            CustomDist(
                "y_hat",
                y0,
                dist=GRW,
                dims=["date", "item"],
                observed=np.ones((10, 1)),
            )
        logp_graph = m.logp()
        assert isinstance(logp_graph, pt.TensorVariable)
        assert logp_graph.ndim == 0

    def test_explicit_rng(self):
        def custom_dist(mu, size):
            return Normal.dist(mu, size=size)