Skip to content

Fix progress bar error when nested CompoundStep samplers are assigned #7730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pymc/backends/ndarray.py
Original file line number Diff line number Diff line change
@@ -113,6 +113,10 @@ def record(self, point, sampler_stats=None) -> None:
if sampler_stats is not None:
for data, vars in zip(self._stats, sampler_stats):
for key, val in vars.items():
# step_meta is a key used by the progress bars to track which draw came from which step instance. It
# should never be stored as a sampler statistic.
if key == "step_meta":
continue
data[key][draw_idx] = val
elif self._stats is not None:
raise ValueError("Expected sampler_stats")
11 changes: 10 additions & 1 deletion pymc/sampling/mcmc.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
"""Functions for MCMC sampling."""

import contextlib
import itertools
import logging
import pickle
import sys
@@ -111,6 +112,7 @@ def instantiate_steppers(
step_kwargs: dict[str, dict] | None = None,
initial_point: PointType | None = None,
compile_kwargs: dict | None = None,
step_id_generator: Iterator[int] | None = None,
) -> Step | list[Step]:
"""Instantiate steppers assigned to the model variables.

@@ -139,6 +141,9 @@ def instantiate_steppers(
if step_kwargs is None:
step_kwargs = {}

if step_id_generator is None:
step_id_generator = itertools.count()

used_keys = set()
if selected_steps:
if initial_point is None:
@@ -154,6 +159,7 @@ def instantiate_steppers(
model=model,
initial_point=initial_point,
compile_kwargs=compile_kwargs,
step_id_generator=step_id_generator,
**kwargs,
)
steps.append(step)
@@ -853,16 +859,19 @@ def joined_blas_limiter():
initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)]

# Instantiate automatically selected steps
# Use a counter to generate a unique id for each stepper used in the model.
step_id_generator = itertools.count()
step = instantiate_steppers(
model,
steps=provided_steps,
selected_steps=selected_steps,
step_kwargs=kwargs,
initial_point=initial_points[0],
compile_kwargs=compile_kwargs,
step_id_generator=step_id_generator,
)
if isinstance(step, list):
step = CompoundStep(step)
step = CompoundStep(step, step_id_generator=step_id_generator)

if var_names is not None:
trace_vars = [v for v in model.unobserved_RVs if v.name in var_names]
37 changes: 32 additions & 5 deletions pymc/step_methods/arraystep.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@
# limitations under the License.

from abc import abstractmethod
from collections.abc import Callable
from collections.abc import Callable, Iterator
from typing import cast

import numpy as np
@@ -43,14 +43,25 @@ class ArrayStep(BlockedStep):
:py:func:`pymc.util.get_random_generator` for more information.
"""

def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None):
def __init__(
self,
vars,
fs,
allvars=False,
blocked=True,
rng: RandomGenerator = None,
step_id_generator: Iterator[int] | None = None,
):
self.vars = vars
self.fs = fs
self.allvars = allvars
self.blocked = blocked
self.rng = get_random_generator(rng)
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
partial_funcs_and_point: list[Callable | PointType] = [
DictToArrayBijection.mapf(x, start_point=point) for x in self.fs
]
@@ -61,6 +72,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
apoint = DictToArrayBijection.map(var_dict)
apoint_new, stats = self.astep(apoint, *partial_funcs_and_point)

for sts in stats:
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}

if not isinstance(apoint_new, RaveledVars):
# We assume that the mapping has stayed the same
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
@@ -84,7 +98,14 @@ class ArrayStepShared(BlockedStep):
and unmapping overhead as well as moving fewer variables around.
"""

def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
def __init__(
self,
vars,
shared,
blocked=True,
rng: RandomGenerator = None,
step_id_generator: Iterator[int] | None = None,
):
"""
Create the ArrayStepShared object.

@@ -103,8 +124,11 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None):
self.shared = {get_var_name(var): shared for var, shared in shared.items()}
self.blocked = blocked
self.rng = get_random_generator(rng)
self._step_id = next(step_id_generator) if step_id_generator else None

def step(self, point: PointType) -> tuple[PointType, StatsType]:
def step(
self, point: PointType, step_parent_id: int | None = None
) -> tuple[PointType, StatsType]:
full_point = None
if self.shared:
for name, shared_var in self.shared.items():
@@ -115,6 +139,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]:
q = DictToArrayBijection.map(point)
apoint, stats = self.astep(q)

for sts in stats:
sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id}

if not isinstance(apoint, RaveledVars):
# We assume that the mapping has stayed the same
apoint = RaveledVars(apoint, q.point_map_info)
Loading