Skip to content

Commit 0db176c

Browse files
Show one progress bar per chain when sampling (#7634)
* One progress bar per chain when samplings * Add guard against divide by zero when computing draws per second * No more purple * Step samplers are responsible for setting up progress bars * Fix typos * Add progressbar defaults to BlockedStep ABC * pre-commit * Only update NUTS divergence stats after tuning * Add `Elapsed` and `Remaining` columns * Remove green color when chain finishes * Create `ProgressManager` class to handle progress bars * Yield `stats` from `_iter_sample` * Use `ProgressManager` in `_sample_many` * pre-commit * Explicit case handling for `progressbar` argument * Allow all permutations of arguments to progressbar * Appease mypy * Add True case * Fix final count when `progress = "combined"` * Update docstrings * mypy + cleanup * Syntax error in typehint * Simplify progressbar choices, update docstring * Incorporate feedback * Be verbose with progressbar settings
1 parent 472da97 commit 0db176c

File tree

8 files changed

+576
-111
lines changed

8 files changed

+576
-111
lines changed

pymc/backends/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
RunType: TypeAlias = Run
8888
HAS_MCB = True
8989
except ImportError:
90-
TraceOrBackend = BaseTrace # type: ignore[misc]
90+
TraceOrBackend = BaseTrace # type: ignore[assignment, misc]
9191
RunType = type(None) # type: ignore[assignment, misc]
9292

9393

pymc/sampling/mcmc.py

+76-70
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@
3636
from arviz import InferenceData, dict_to_dataset
3737
from arviz.data.base import make_attrs
3838
from pytensor.graph.basic import Variable
39-
from rich.console import Console
40-
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
4139
from rich.theme import Theme
4240
from threadpoolctl import threadpool_limits
4341
from typing_extensions import Protocol
@@ -67,7 +65,8 @@
6765
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
6866
from pymc.step_methods.hmc import quadpotential
6967
from pymc.util import (
70-
CustomProgress,
68+
ProgressBarManager,
69+
ProgressBarType,
7170
RandomSeed,
7271
RandomState,
7372
_get_seeds_per_chain,
@@ -278,7 +277,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
278277
else:
279278
varnames = ", ".join(
280279
[
281-
get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name
280+
get_untransformed_name(v.name) if is_transformed_name(v.name) else v.name # type: ignore[misc]
282281
for v in s.vars
283282
]
284283
)
@@ -425,7 +424,7 @@ def sample(
425424
chains: int | None = None,
426425
cores: int | None = None,
427426
random_seed: RandomState = None,
428-
progressbar: bool = True,
427+
progressbar: bool | ProgressBarType = True,
429428
progressbar_theme: Theme | None = default_progress_theme,
430429
step=None,
431430
var_names: Sequence[str] | None = None,
@@ -457,7 +456,7 @@ def sample(
457456
chains: int | None = None,
458457
cores: int | None = None,
459458
random_seed: RandomState = None,
460-
progressbar: bool = True,
459+
progressbar: bool | ProgressBarType = True,
461460
progressbar_theme: Theme | None = default_progress_theme,
462461
step=None,
463462
var_names: Sequence[str] | None = None,
@@ -489,8 +488,8 @@ def sample(
489488
chains: int | None = None,
490489
cores: int | None = None,
491490
random_seed: RandomState = None,
492-
progressbar: bool = True,
493-
progressbar_theme: Theme | None = default_progress_theme,
491+
progressbar: bool | ProgressBarType = True,
492+
progressbar_theme: Theme | None = None,
494493
step=None,
495494
var_names: Sequence[str] | None = None,
496495
nuts_sampler: Literal["pymc", "nutpie", "numpyro", "blackjax"] = "pymc",
@@ -540,11 +539,18 @@ def sample(
540539
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
541540
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
542541
easy spawning of new independent random streams that are needed by the step methods.
543-
progressbar : bool, optional default=True
544-
Whether or not to display a progress bar in the command line. The bar shows the percentage
545-
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
546-
time until completion ("expected time of arrival"; ETA).
547-
Only applicable to the pymc nuts sampler.
542+
progressbar: bool or ProgressType, optional
543+
How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
544+
for one of the following:
545+
- "combined": A single progress bar that displays the total progress across all chains. Only timing
546+
information is shown.
547+
- "split": A separate progress bar for each chain. Only timing information is shown.
548+
- "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
549+
chains. Aggregate sample statistics are also displayed.
550+
- "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
551+
are also displayed.
552+
553+
If True, the default is "split+stats" is used.
548554
step : function or iterable of functions
549555
A step function or collection of functions. If there are variables without step methods,
550556
step methods for those variables will be assigned automatically. By default the NUTS step
@@ -710,6 +716,10 @@ def sample(
710716
if isinstance(trace, list):
711717
raise ValueError("Please use `var_names` keyword argument for partial traces.")
712718

719+
# progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and
720+
# ADVI initialization expect just a bool.
721+
progress_bool = bool(progressbar)
722+
713723
model = modelcontext(model)
714724
if not model.free_RVs:
715725
raise SamplingError(
@@ -806,7 +816,7 @@ def joined_blas_limiter():
806816
initvals=initvals,
807817
model=model,
808818
var_names=var_names,
809-
progressbar=progressbar,
819+
progressbar=progress_bool,
810820
idata_kwargs=idata_kwargs,
811821
compute_convergence_checks=compute_convergence_checks,
812822
nuts_sampler_kwargs=nuts_sampler_kwargs,
@@ -825,7 +835,7 @@ def joined_blas_limiter():
825835
n_init=n_init,
826836
model=model,
827837
random_seed=random_seed_list,
828-
progressbar=progressbar,
838+
progressbar=progress_bool,
829839
jitter_max_retries=jitter_max_retries,
830840
tune=tune,
831841
initvals=initvals,
@@ -1139,34 +1149,44 @@ def _sample_many(
11391149
Step function
11401150
"""
11411151
initial_step_state = step.sampling_state
1142-
for i in range(chains):
1143-
step.sampling_state = initial_step_state
1144-
_sample(
1145-
draws=draws,
1146-
chain=i,
1147-
start=start[i],
1148-
step=step,
1149-
trace=traces[i],
1150-
rng=rngs[i],
1151-
callback=callback,
1152-
**kwargs,
1153-
)
1152+
progress_manager = ProgressBarManager(
1153+
step_method=step,
1154+
chains=chains,
1155+
draws=draws - kwargs.get("tune", 0),
1156+
tune=kwargs.get("tune", 0),
1157+
progressbar=kwargs.get("progressbar", True),
1158+
progressbar_theme=kwargs.get("progressbar_theme", default_progress_theme),
1159+
)
1160+
1161+
with progress_manager:
1162+
for i in range(chains):
1163+
step.sampling_state = initial_step_state
1164+
_sample(
1165+
draws=draws,
1166+
chain=i,
1167+
start=start[i],
1168+
step=step,
1169+
trace=traces[i],
1170+
rng=rngs[i],
1171+
callback=callback,
1172+
progress_manager=progress_manager,
1173+
**kwargs,
1174+
)
11541175
return
11551176

11561177

11571178
def _sample(
11581179
*,
11591180
chain: int,
1160-
progressbar: bool,
11611181
rng: np.random.Generator,
11621182
start: PointType,
11631183
draws: int,
11641184
step: Step,
11651185
trace: IBaseTrace,
11661186
tune: int,
11671187
model: Model | None = None,
1168-
progressbar_theme: Theme | None = default_progress_theme,
11691188
callback=None,
1189+
progress_manager: ProgressBarManager,
11701190
**kwargs,
11711191
) -> None:
11721192
"""Sample one chain (singleprocess).
@@ -1177,27 +1197,23 @@ def _sample(
11771197
----------
11781198
chain : int
11791199
Number of the chain that the samples will belong to.
1180-
progressbar : bool
1181-
Whether or not to display a progress bar in the command line. The bar shows the percentage
1182-
of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1183-
time until completion ("expected time of arrival"; ETA).
1184-
random_seed : single random seed
1200+
random_seed : Generator
1201+
Single random seed
11851202
start : dict
11861203
Starting point in parameter space (or partial point)
11871204
draws : int
11881205
The number of samples to draw
1189-
step : function
1190-
Step function
1206+
step : Step
1207+
Step class instance used to generate samples.
11911208
trace
11921209
A chain backend to record draws and stats.
11931210
tune : int
11941211
Number of iterations to tune.
1195-
model : Model (optional if in ``with`` context)
1196-
progressbar_theme : Theme
1197-
Optional custom theme for the progress bar.
1212+
model : Model, optional
1213+
PyMC model. If None, the model is taken from the current context.
1214+
progress_manager: ProgressBarManager
1215+
Helper class used to handle progress bar styling and updates
11981216
"""
1199-
skip_first = kwargs.get("skip_first", 0)
1200-
12011217
sampling_gen = _iter_sample(
12021218
draws=draws,
12031219
step=step,
@@ -1209,32 +1225,19 @@ def _sample(
12091225
rng=rng,
12101226
callback=callback,
12111227
)
1212-
_pbar_data = {"chain": chain, "divergences": 0}
1213-
_desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1214-
1215-
progress = CustomProgress(
1216-
"[progress.description]{task.description}",
1217-
BarColumn(),
1218-
"[progress.percentage]{task.percentage:>3.0f}%",
1219-
TimeRemainingColumn(),
1220-
TextColumn("/"),
1221-
TimeElapsedColumn(),
1222-
console=Console(theme=progressbar_theme),
1223-
disable=not progressbar,
1224-
)
1228+
try:
1229+
for it, stats in enumerate(sampling_gen):
1230+
progress_manager.update(
1231+
chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
1232+
)
12251233

1226-
with progress:
1227-
try:
1228-
task = progress.add_task(_desc.format(**_pbar_data), completed=0, total=draws)
1229-
for it, diverging in enumerate(sampling_gen):
1230-
if it >= skip_first and diverging:
1231-
_pbar_data["divergences"] += 1
1232-
progress.update(task, description=_desc.format(**_pbar_data), completed=it)
1233-
progress.update(
1234-
task, description=_desc.format(**_pbar_data), completed=draws, refresh=True
1234+
if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
1235+
progress_manager.update(
1236+
chain_idx=chain, is_last=True, draw=it, stats=stats, tuning=False
12351237
)
1236-
except KeyboardInterrupt:
1237-
pass
1238+
1239+
except KeyboardInterrupt:
1240+
pass
12381241

12391242

12401243
def _iter_sample(
@@ -1248,7 +1251,7 @@ def _iter_sample(
12481251
rng: np.random.Generator,
12491252
model: Model | None = None,
12501253
callback: SamplingIteratorCallback | None = None,
1251-
) -> Iterator[bool]:
1254+
) -> Iterator[list[dict[str, Any]]]:
12521255
"""Sample one chain with a generator (singleprocess).
12531256
12541257
Parameters
@@ -1271,8 +1274,8 @@ def _iter_sample(
12711274
12721275
Yields
12731276
------
1274-
diverging : bool
1275-
Indicates if the draw is divergent. Only available with some samplers.
1277+
stats : list of dict
1278+
Dictionary of statistics returned by step sampler
12761279
"""
12771280
draws = int(draws)
12781281

@@ -1294,22 +1297,25 @@ def _iter_sample(
12941297
step.iter_count = 0
12951298
if i == tune:
12961299
step.stop_tuning()
1300+
12971301
point, stats = step.step(point)
12981302
trace.record(point, stats)
12991303
log_warning_stats(stats)
1300-
diverging = i > tune and len(stats) > 0 and (stats[0].get("diverging") is True)
1304+
13011305
if callback is not None:
13021306
callback(
13031307
trace=trace,
13041308
draw=Draw(chain, i == draws, i, i < tune, stats, point),
13051309
)
13061310

1307-
yield diverging
1311+
yield stats
1312+
13081313
except (KeyboardInterrupt, BaseException):
13091314
if isinstance(trace, ZarrChain):
13101315
trace.record_sampling_state(step=step)
13111316
trace.close()
13121317
raise
1318+
13131319
else:
13141320
if isinstance(trace, ZarrChain):
13151321
trace.record_sampling_state(step=step)

pymc/sampling/parallel.py

+12-35
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,14 @@
2727
import cloudpickle
2828
import numpy as np
2929

30-
from rich.console import Console
31-
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
3230
from rich.theme import Theme
3331
from threadpoolctl import threadpool_limits
3432

3533
from pymc.backends.zarr import ZarrChain
3634
from pymc.blocking import DictToArrayBijection
3735
from pymc.exceptions import SamplingError
3836
from pymc.util import (
39-
CustomProgress,
37+
ProgressBarManager,
4038
RandomGeneratorState,
4139
default_progress_theme,
4240
get_state_from_generator,
@@ -485,23 +483,14 @@ def __init__(
485483
self._max_active = cores
486484

487485
self._in_context = False
488-
489-
self._progress = CustomProgress(
490-
"[progress.description]{task.description}",
491-
BarColumn(),
492-
"[progress.percentage]{task.percentage:>3.0f}%",
493-
TimeRemainingColumn(),
494-
TextColumn("/"),
495-
TimeElapsedColumn(),
496-
console=Console(theme=progressbar_theme),
497-
disable=not progressbar,
486+
self._progress = ProgressBarManager(
487+
step_method=step_method,
488+
chains=chains,
489+
draws=draws,
490+
tune=tune,
491+
progressbar=progressbar,
492+
progressbar_theme=progressbar_theme,
498493
)
499-
self._show_progress = progressbar
500-
self._divergences = 0
501-
self._completed_draws = 0
502-
self._total_draws = chains * (draws + tune)
503-
self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
504-
self._chains = chains
505494

506495
def _make_active(self):
507496
while self._inactive and len(self._active) < self._max_active:
@@ -516,32 +505,20 @@ def __iter__(self):
516505
raise ValueError("Use ParallelSampler as context manager.")
517506
self._make_active()
518507

519-
with self._progress as progress:
520-
task = progress.add_task(
521-
self._desc.format(self),
522-
completed=self._completed_draws,
523-
total=self._total_draws,
524-
)
525-
508+
with self._progress:
526509
while self._active:
527510
draw = ProcessAdapter.recv_draw(self._active)
528511
proc, is_last, draw, tuning, stats = draw
529-
self._completed_draws += 1
530-
if not tuning and stats and stats[0].get("diverging"):
531-
self._divergences += 1
532-
progress.update(
533-
task,
534-
completed=self._completed_draws,
535-
total=self._total_draws,
536-
description=self._desc.format(self),
512+
513+
self._progress.update(
514+
chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
537515
)
538516

539517
if is_last:
540518
proc.join()
541519
self._active.remove(proc)
542520
self._finished.append(proc)
543521
self._make_active()
544-
progress.update(task, description=self._desc.format(self), refresh=True)
545522

546523
# We could also yield proc.shared_point_view directly,
547524
# and only call proc.write_next() after the yield returns.

0 commit comments

Comments
 (0)