36
36
from arviz import InferenceData , dict_to_dataset
37
37
from arviz .data .base import make_attrs
38
38
from pytensor .graph .basic import Variable
39
- from rich .console import Console
40
- from rich .progress import BarColumn , TextColumn , TimeElapsedColumn , TimeRemainingColumn
41
39
from rich .theme import Theme
42
40
from threadpoolctl import threadpool_limits
43
41
from typing_extensions import Protocol
67
65
from pymc .step_methods .arraystep import BlockedStep , PopulationArrayStepShared
68
66
from pymc .step_methods .hmc import quadpotential
69
67
from pymc .util import (
70
- CustomProgress ,
68
+ ProgressBarManager ,
69
+ ProgressBarType ,
71
70
RandomSeed ,
72
71
RandomState ,
73
72
_get_seeds_per_chain ,
@@ -278,7 +277,7 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
278
277
else :
279
278
varnames = ", " .join (
280
279
[
281
- get_untransformed_name (v .name ) if is_transformed_name (v .name ) else v .name
280
+ get_untransformed_name (v .name ) if is_transformed_name (v .name ) else v .name # type: ignore[misc]
282
281
for v in s .vars
283
282
]
284
283
)
@@ -425,7 +424,7 @@ def sample(
425
424
chains : int | None = None ,
426
425
cores : int | None = None ,
427
426
random_seed : RandomState = None ,
428
- progressbar : bool = True ,
427
+ progressbar : bool | ProgressBarType = True ,
429
428
progressbar_theme : Theme | None = default_progress_theme ,
430
429
step = None ,
431
430
var_names : Sequence [str ] | None = None ,
@@ -457,7 +456,7 @@ def sample(
457
456
chains : int | None = None ,
458
457
cores : int | None = None ,
459
458
random_seed : RandomState = None ,
460
- progressbar : bool = True ,
459
+ progressbar : bool | ProgressBarType = True ,
461
460
progressbar_theme : Theme | None = default_progress_theme ,
462
461
step = None ,
463
462
var_names : Sequence [str ] | None = None ,
@@ -489,8 +488,8 @@ def sample(
489
488
chains : int | None = None ,
490
489
cores : int | None = None ,
491
490
random_seed : RandomState = None ,
492
- progressbar : bool = True ,
493
- progressbar_theme : Theme | None = default_progress_theme ,
491
+ progressbar : bool | ProgressBarType = True ,
492
+ progressbar_theme : Theme | None = None ,
494
493
step = None ,
495
494
var_names : Sequence [str ] | None = None ,
496
495
nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
@@ -540,11 +539,18 @@ def sample(
540
539
A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
541
540
We no longer support ``RandomState`` objects because their seeding mechanism does not allow
542
541
easy spawning of new independent random streams that are needed by the step methods.
543
- progressbar : bool, optional default=True
544
- Whether or not to display a progress bar in the command line. The bar shows the percentage
545
- of completion, the sampling speed in samples per second (SPS), and the estimated remaining
546
- time until completion ("expected time of arrival"; ETA).
547
- Only applicable to the pymc nuts sampler.
542
+ progressbar: bool or ProgressType, optional
543
+ How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
544
+ for one of the following:
545
+ - "combined": A single progress bar that displays the total progress across all chains. Only timing
546
+ information is shown.
547
+ - "split": A separate progress bar for each chain. Only timing information is shown.
548
+ - "combined+stats" or "stats+combined": A single progress bar displaying the total progress across all
549
+ chains. Aggregate sample statistics are also displayed.
550
+ - "split+stats" or "stats+split": A separate progress bar for each chain. Sample statistics for each chain
551
+ are also displayed.
552
+
553
+ If True, the default is "split+stats" is used.
548
554
step : function or iterable of functions
549
555
A step function or collection of functions. If there are variables without step methods,
550
556
step methods for those variables will be assigned automatically. By default the NUTS step
@@ -710,6 +716,10 @@ def sample(
710
716
if isinstance (trace , list ):
711
717
raise ValueError ("Please use `var_names` keyword argument for partial traces." )
712
718
719
+ # progressbar might be a string, which is used by the ProgressManager in the pymc samplers. External samplers and
720
+ # ADVI initialization expect just a bool.
721
+ progress_bool = bool (progressbar )
722
+
713
723
model = modelcontext (model )
714
724
if not model .free_RVs :
715
725
raise SamplingError (
@@ -806,7 +816,7 @@ def joined_blas_limiter():
806
816
initvals = initvals ,
807
817
model = model ,
808
818
var_names = var_names ,
809
- progressbar = progressbar ,
819
+ progressbar = progress_bool ,
810
820
idata_kwargs = idata_kwargs ,
811
821
compute_convergence_checks = compute_convergence_checks ,
812
822
nuts_sampler_kwargs = nuts_sampler_kwargs ,
@@ -825,7 +835,7 @@ def joined_blas_limiter():
825
835
n_init = n_init ,
826
836
model = model ,
827
837
random_seed = random_seed_list ,
828
- progressbar = progressbar ,
838
+ progressbar = progress_bool ,
829
839
jitter_max_retries = jitter_max_retries ,
830
840
tune = tune ,
831
841
initvals = initvals ,
@@ -1139,34 +1149,44 @@ def _sample_many(
1139
1149
Step function
1140
1150
"""
1141
1151
initial_step_state = step .sampling_state
1142
- for i in range (chains ):
1143
- step .sampling_state = initial_step_state
1144
- _sample (
1145
- draws = draws ,
1146
- chain = i ,
1147
- start = start [i ],
1148
- step = step ,
1149
- trace = traces [i ],
1150
- rng = rngs [i ],
1151
- callback = callback ,
1152
- ** kwargs ,
1153
- )
1152
+ progress_manager = ProgressBarManager (
1153
+ step_method = step ,
1154
+ chains = chains ,
1155
+ draws = draws - kwargs .get ("tune" , 0 ),
1156
+ tune = kwargs .get ("tune" , 0 ),
1157
+ progressbar = kwargs .get ("progressbar" , True ),
1158
+ progressbar_theme = kwargs .get ("progressbar_theme" , default_progress_theme ),
1159
+ )
1160
+
1161
+ with progress_manager :
1162
+ for i in range (chains ):
1163
+ step .sampling_state = initial_step_state
1164
+ _sample (
1165
+ draws = draws ,
1166
+ chain = i ,
1167
+ start = start [i ],
1168
+ step = step ,
1169
+ trace = traces [i ],
1170
+ rng = rngs [i ],
1171
+ callback = callback ,
1172
+ progress_manager = progress_manager ,
1173
+ ** kwargs ,
1174
+ )
1154
1175
return
1155
1176
1156
1177
1157
1178
def _sample (
1158
1179
* ,
1159
1180
chain : int ,
1160
- progressbar : bool ,
1161
1181
rng : np .random .Generator ,
1162
1182
start : PointType ,
1163
1183
draws : int ,
1164
1184
step : Step ,
1165
1185
trace : IBaseTrace ,
1166
1186
tune : int ,
1167
1187
model : Model | None = None ,
1168
- progressbar_theme : Theme | None = default_progress_theme ,
1169
1188
callback = None ,
1189
+ progress_manager : ProgressBarManager ,
1170
1190
** kwargs ,
1171
1191
) -> None :
1172
1192
"""Sample one chain (singleprocess).
@@ -1177,27 +1197,23 @@ def _sample(
1177
1197
----------
1178
1198
chain : int
1179
1199
Number of the chain that the samples will belong to.
1180
- progressbar : bool
1181
- Whether or not to display a progress bar in the command line. The bar shows the percentage
1182
- of completion, the sampling speed in samples per second (SPS), and the estimated remaining
1183
- time until completion ("expected time of arrival"; ETA).
1184
- random_seed : single random seed
1200
+ random_seed : Generator
1201
+ Single random seed
1185
1202
start : dict
1186
1203
Starting point in parameter space (or partial point)
1187
1204
draws : int
1188
1205
The number of samples to draw
1189
- step : function
1190
- Step function
1206
+ step : Step
1207
+ Step class instance used to generate samples.
1191
1208
trace
1192
1209
A chain backend to record draws and stats.
1193
1210
tune : int
1194
1211
Number of iterations to tune.
1195
- model : Model (optional if in ``with`` context)
1196
- progressbar_theme : Theme
1197
- Optional custom theme for the progress bar.
1212
+ model : Model, optional
1213
+ PyMC model. If None, the model is taken from the current context.
1214
+ progress_manager: ProgressBarManager
1215
+ Helper class used to handle progress bar styling and updates
1198
1216
"""
1199
- skip_first = kwargs .get ("skip_first" , 0 )
1200
-
1201
1217
sampling_gen = _iter_sample (
1202
1218
draws = draws ,
1203
1219
step = step ,
@@ -1209,32 +1225,19 @@ def _sample(
1209
1225
rng = rng ,
1210
1226
callback = callback ,
1211
1227
)
1212
- _pbar_data = {"chain" : chain , "divergences" : 0 }
1213
- _desc = "Sampling chain {chain:d}, {divergences:,d} divergences"
1214
-
1215
- progress = CustomProgress (
1216
- "[progress.description]{task.description}" ,
1217
- BarColumn (),
1218
- "[progress.percentage]{task.percentage:>3.0f}%" ,
1219
- TimeRemainingColumn (),
1220
- TextColumn ("/" ),
1221
- TimeElapsedColumn (),
1222
- console = Console (theme = progressbar_theme ),
1223
- disable = not progressbar ,
1224
- )
1228
+ try :
1229
+ for it , stats in enumerate (sampling_gen ):
1230
+ progress_manager .update (
1231
+ chain_idx = chain , is_last = False , draw = it , stats = stats , tuning = it > tune
1232
+ )
1225
1233
1226
- with progress :
1227
- try :
1228
- task = progress .add_task (_desc .format (** _pbar_data ), completed = 0 , total = draws )
1229
- for it , diverging in enumerate (sampling_gen ):
1230
- if it >= skip_first and diverging :
1231
- _pbar_data ["divergences" ] += 1
1232
- progress .update (task , description = _desc .format (** _pbar_data ), completed = it )
1233
- progress .update (
1234
- task , description = _desc .format (** _pbar_data ), completed = draws , refresh = True
1234
+ if not progress_manager .combined_progress or chain == progress_manager .chains - 1 :
1235
+ progress_manager .update (
1236
+ chain_idx = chain , is_last = True , draw = it , stats = stats , tuning = False
1235
1237
)
1236
- except KeyboardInterrupt :
1237
- pass
1238
+
1239
+ except KeyboardInterrupt :
1240
+ pass
1238
1241
1239
1242
1240
1243
def _iter_sample (
@@ -1248,7 +1251,7 @@ def _iter_sample(
1248
1251
rng : np .random .Generator ,
1249
1252
model : Model | None = None ,
1250
1253
callback : SamplingIteratorCallback | None = None ,
1251
- ) -> Iterator [bool ]:
1254
+ ) -> Iterator [list [ dict [ str , Any ]] ]:
1252
1255
"""Sample one chain with a generator (singleprocess).
1253
1256
1254
1257
Parameters
@@ -1271,8 +1274,8 @@ def _iter_sample(
1271
1274
1272
1275
Yields
1273
1276
------
1274
- diverging : bool
1275
- Indicates if the draw is divergent. Only available with some samplers.
1277
+ stats : list of dict
1278
+ Dictionary of statistics returned by step sampler
1276
1279
"""
1277
1280
draws = int (draws )
1278
1281
@@ -1294,22 +1297,25 @@ def _iter_sample(
1294
1297
step .iter_count = 0
1295
1298
if i == tune :
1296
1299
step .stop_tuning ()
1300
+
1297
1301
point , stats = step .step (point )
1298
1302
trace .record (point , stats )
1299
1303
log_warning_stats (stats )
1300
- diverging = i > tune and len ( stats ) > 0 and ( stats [ 0 ]. get ( "diverging" ) is True )
1304
+
1301
1305
if callback is not None :
1302
1306
callback (
1303
1307
trace = trace ,
1304
1308
draw = Draw (chain , i == draws , i , i < tune , stats , point ),
1305
1309
)
1306
1310
1307
- yield diverging
1311
+ yield stats
1312
+
1308
1313
except (KeyboardInterrupt , BaseException ):
1309
1314
if isinstance (trace , ZarrChain ):
1310
1315
trace .record_sampling_state (step = step )
1311
1316
trace .close ()
1312
1317
raise
1318
+
1313
1319
else :
1314
1320
if isinstance (trace , ZarrChain ):
1315
1321
trace .record_sampling_state (step = step )
0 commit comments