Skip to content

Commit 5c52244

Browse files
authored
Some love to model comparison (#315)
* add .predict() to model comparison approximator * mc_calibration: pred_models should be first * changes to the model comparison simulator: - individual models can rely on a shared simulations (makes mixed batches much more useable) - mixed batches are sampled more efficiently (batched samples per model, rather than batching individual simulations) * fix mc_confusion_matrix plot * add simple model comparison notebook * flip pred_models and true_models in the docs
1 parent 1dbaedd commit 5c52244

File tree

6 files changed

+642
-35
lines changed

6 files changed

+642
-35
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ Check out some of our walk-through notebooks below. We are actively working on p
9999
4. [SBML model using an external simulator](examples/From_ABC_to_BayesFlow.ipynb)
100100
5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb)
101101
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
102-
7. More coming soon...
102+
7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb)
103+
8. More coming soon...
103104

104105
## Documentation \& Help
105106

bayesflow/approximators/model_comparison_approximator.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Mapping, Sequence
22

33
import keras
4+
import numpy as np
45
from keras.saving import (
56
deserialize_keras_object as deserialize,
67
register_keras_serializable as serializable,
@@ -198,3 +199,44 @@ def get_config(self):
198199
}
199200

200201
return base_config | config
202+
203+
def predict(
204+
self,
205+
*,
206+
conditions: dict[str, np.ndarray],
207+
logits: bool = False,
208+
**kwargs,
209+
) -> np.ndarray:
210+
conditions = self.adapter(conditions, strict=False, stage="inference", **kwargs)
211+
conditions = keras.tree.map_structure(keras.ops.convert_to_tensor, conditions)
212+
213+
output = self._predict(**conditions, **kwargs)
214+
215+
if not logits:
216+
output = keras.ops.softmax(output)
217+
218+
output = keras.ops.convert_to_numpy(output)
219+
220+
return output
221+
222+
def _predict(self, classifier_conditions: Tensor = None, summary_variables: Tensor = None, **kwargs) -> Tensor:
223+
if self.summary_network is None:
224+
if summary_variables is not None:
225+
raise ValueError("Cannot use summary variables without a summary network.")
226+
else:
227+
if summary_variables is None:
228+
raise ValueError("Summary variables are required when a summary network is present")
229+
230+
summary_outputs = self.summary_network(
231+
summary_variables, **filter_kwargs(kwargs, self.summary_network.call)
232+
)
233+
234+
if classifier_conditions is None:
235+
classifier_conditions = summary_outputs
236+
else:
237+
classifier_conditions = keras.ops.concatenate([classifier_conditions, summary_outputs], axis=1)
238+
239+
output = self.classifier_network(classifier_conditions)
240+
output = self.logits_projector(output)
241+
242+
return output

bayesflow/diagnostics/plots/mc_calibration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def mc_calibration(
3434
3535
Parameters
3636
----------
37-
true_models : np.ndarray of shape (num_data_sets, num_models)
38-
The one-hot-encoded true model indices per data set.
3937
pred_models : np.ndarray of shape (num_data_sets, num_models)
4038
The predicted posterior model probabilities (PMPs) per data set.
39+
true_models : np.ndarray of shape (num_data_sets, num_models)
40+
The one-hot-encoded true model indices per data set.
4141
model_names : list or None, optional, default: None
4242
The model names for nice plot titles. Inferred if None.
4343
num_bins : int, optional, default: 10

bayesflow/diagnostics/plots/mc_confusion_matrix.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313

1414

1515
def mc_confusion_matrix(
16-
true_models: dict[str, np.ndarray] | np.ndarray,
1716
pred_models: dict[str, np.ndarray] | np.ndarray,
17+
true_models: dict[str, np.ndarray] | np.ndarray,
1818
model_names: Sequence[str] = None,
1919
fig_size: tuple = (5, 5),
2020
label_fontsize: int = 16,
@@ -23,18 +23,18 @@ def mc_confusion_matrix(
2323
tick_fontsize: int = 12,
2424
xtick_rotation: int = None,
2525
ytick_rotation: int = None,
26-
normalize: bool = True,
26+
normalize: str = None,
2727
cmap: matplotlib.colors.Colormap | str = None,
2828
title: bool = True,
2929
) -> plt.Figure:
3030
"""Plots a confusion matrix for validating a neural network trained for Bayesian model comparison.
3131
3232
Parameters
3333
----------
34-
true_models : np.ndarray of shape (num_data_sets, num_models)
35-
The one-hot-encoded true model indices per data set.
3634
pred_models : np.ndarray of shape (num_data_sets, num_models)
3735
The predicted posterior model probabilities (PMPs) per data set.
36+
true_models : np.ndarray of shape (num_data_sets, num_models)
37+
The one-hot-encoded true model indices per data set.
3838
model_names : list or None, optional, default: None
3939
The model names for nice plot titles. Inferred if None.
4040
fig_size : tuple or None, optional, default: (5, 5)
@@ -51,9 +51,11 @@ def mc_confusion_matrix(
5151
Rotation of x-axis tick labels (helps with long model names).
5252
ytick_rotation: int, optional, default: None
5353
Rotation of y-axis tick labels (helps with long model names).
54-
normalize : bool, optional, default: True
55-
A flag for normalization of the confusion matrix.
56-
If True, each row of the confusion matrix is normalized to sum to 1.
54+
normalize : {'true', 'pred', 'all'}, default=None
55+
Passed to sklearn.metrics.confusion_matrix.
56+
Normalizes confusion matrix over the true (rows), predicted (columns)
57+
conditions or all the population. If None, confusion matrix will not be
58+
normalized.
5759
cmap : matplotlib.colors.Colormap or str, optional, default: None
5860
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
5961
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
@@ -77,29 +79,26 @@ def mc_confusion_matrix(
7779
pred_models = ops.argmax(pred_models, axis=1)
7880

7981
# Compute confusion matrix
80-
cm = confusion_matrix(true_models, pred_models)
81-
82-
# if normalize:
83-
# # Sum along rows and keep dimensions for broadcasting
84-
# cm_sum = ops.sum(cm, axis=1, keepdims=True)
85-
#
86-
# # Broadcast division for normalization
87-
# cm_normalized = cm / cm_sum
82+
cm = confusion_matrix(true_models, pred_models, normalize=normalize)
8883

8984
# Initialize figure
9085
fig, ax = make_figure(1, 1, figsize=fig_size)
86+
ax = ax[0]
9187
im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
9288
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)
9389

9490
cbar.ax.tick_params(labelsize=value_fontsize)
9591

96-
ax.set(xticks=ops.arange(cm.shape[1]), yticks=ops.arange(cm.shape[0]))
92+
ax.set_xticks(range(cm.shape[0]))
9793
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
9894
if xtick_rotation:
9995
plt.xticks(rotation=xtick_rotation, ha="right")
96+
97+
ax.set_yticks(range(cm.shape[1]))
10098
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
10199
if ytick_rotation:
102100
plt.yticks(rotation=ytick_rotation)
101+
103102
ax.set_xlabel("Predicted model", fontsize=label_fontsize)
104103
ax.set_ylabel("True model", fontsize=label_fontsize)
105104

bayesflow/simulators/model_comparison_simulator.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
import numpy as np
33

44
from bayesflow.types import Shape
5-
from bayesflow.utils import tree_stack
5+
from bayesflow.utils import tree_concatenate
66
from bayesflow.utils.decorators import allow_batch_size
77

88
from bayesflow.utils import numpy_utils as npu
99

10+
from types import FunctionType
11+
1012
from .simulator import Simulator
13+
from .lambda_simulator import LambdaSimulator
1114

1215

1316
class ModelComparisonSimulator(Simulator):
@@ -18,10 +21,15 @@ def __init__(
1821
simulators: Sequence[Simulator],
1922
p: Sequence[float] = None,
2023
logits: Sequence[float] = None,
21-
use_mixed_batches: bool = False,
24+
use_mixed_batches: bool = True,
25+
shared_simulator: Simulator | FunctionType = None,
2226
):
2327
self.simulators = simulators
2428

29+
if isinstance(shared_simulator, FunctionType):
30+
shared_simulator = LambdaSimulator(shared_simulator, is_batched=True)
31+
self.shared_simulator = shared_simulator
32+
2533
match logits, p:
2634
case (None, None):
2735
logits = [0.0] * len(simulators)
@@ -43,30 +51,34 @@ def __init__(
4351

4452
@allow_batch_size
4553
def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]:
54+
data = {}
55+
if self.shared_simulator:
56+
data |= self.shared_simulator.sample(batch_shape, **kwargs)
57+
4658
if not self.use_mixed_batches:
4759
# draw one model index for the whole batch (faster)
4860
model_index = np.random.choice(len(self.simulators), p=npu.softmax(self.logits))
4961

5062
simulator = self.simulators[model_index]
51-
data = simulator.sample(batch_shape)
63+
data = simulator.sample(batch_shape, **(kwargs | data))
5264

5365
model_indices = np.full(batch_shape, model_index, dtype="int32")
66+
model_indices = npu.one_hot(model_indices, len(self.simulators))
5467
else:
55-
# draw a model index for each sample in the batch (slower)
56-
model_indices = np.random.choice(len(self.simulators), p=npu.softmax(self.logits), size=batch_shape)
57-
58-
data = np.empty(batch_shape, dtype="object")
59-
60-
for index in np.ndindex(batch_shape):
61-
simulator = self.simulators[int(model_indices[index])]
62-
data[index] = simulator.sample(())
68+
# generate data randomly from each model (slower)
69+
model_counts = np.random.multinomial(n=batch_shape[0], pvals=npu.softmax(self.logits))
6370

64-
data = data.flatten().tolist()
65-
data = tree_stack(data, axis=0, numpy=True)
71+
sims = []
72+
for n, simulator in zip(model_counts, self.simulators):
73+
if n == 0:
74+
continue
75+
sim = simulator.sample(n, **(kwargs | data))
76+
sims.append(sim)
6677

67-
# restore batch shape
68-
data = {key: np.reshape(value, batch_shape + np.shape(value)[1:]) for key, value in data.items()}
78+
sims = tree_concatenate(sims, numpy=True)
79+
data |= sims
6980

70-
model_indices = npu.one_hot(model_indices, len(self.simulators))
81+
model_indices = np.eye(len(self.simulators), dtype="int32")
82+
model_indices = np.repeat(model_indices, model_counts, axis=0)
7183

7284
return data | {"model_indices": model_indices}

0 commit comments

Comments
 (0)